]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Hexagon add support for f16/f32 flash attention, scale, set-rows and improve f16...
authorMax Krasnyansky <redacted>
Wed, 7 Jan 2026 01:38:29 +0000 (17:38 -0800)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
* hexagon: improve fp16 matmul and add fp32/fp16 flash-attention

* hexagon: add support for set-rows fp32 -> fp16 with i32/i64 row-idx

* hexagon: add support for SCALE fp32

* hexagon: replace scalar fp32 -> fp16 copy with HVX

* hexagon: optimize flash_atten_ext with aligned VTCM buffers and DMA

- Implements double-buffered DMA prefetching for K, V, and Mask tensors.
- Ensures K and V rows in VTCM are padded to 128 bytes to support aligned HVX operations.
- Correctly synchronizes DMA transfers to prevent race conditions.
- Uses `FLASH_ATTN_BLOCK_SIZE` of 128 for efficient chunking.

* hexagon: use aligned mad_f16

* hexagon: flash_atten more aligned ops

* hexagon: optimize scale_f32 hvx helpers

* hexagon: unroll fa loops

* hexagon: remove unused set-rows log

* hexagon: flash_attn_ext add support for DMAing Q

- Update `op_flash_attn_ext` to include Q row size in scratchpad allocation.
- Pad Q row size to 128 bytes for alignment.
- Implement DMA transfer for Q tensor in `flash_attn_ext_f16_thread`.
- Update dot product computations to use VTCM-buffered Q data.

* hexagon: fix handling of NANs hvx dotproducts

* hexagon: cleanup spad allocation in flash-atten

* hexagon: improve fp16/fp32 matmul

- Introduced `vec_dot_f16_f16` and `vec_dot_f16_f16_rx2` kernels using efficient HVX dot product intrinsics.
- Added `quantize_fp32_f16` to copy/convert weights from DDR to VTCM
- Updated `op_matmul` to use the optimized path when VTCM capacity allows and broadcasting requirements are compatible.
- Implemented fallback logic to the original implementation for complex broadcasting scenarios.

* hexagon: fix HVX_ARCH check

* hexagon: matmul cleanup and fp16 fixes

Use aligned vec_dot_f16 for 2d matmuls and unaligned version for 4d.

* hexagon: fix fp16 x fp16 matmuls and some minor refactoring

* hexagon: add support for GET_ROWS f32 -> f32

Also optimize SET_ROWS threading a bit when we have just a few rows to process.

* hexagon: optimize set-rows threading

* hexagon: update adb/run-bench.sh to properly support experimental and verbose options

* hexagon: flash_atten use aligned vectors for dot products

14 files changed:
ggml/src/ggml-hexagon/ggml-hexagon.cpp
ggml/src/ggml-hexagon/htp/CMakeLists.txt
ggml/src/ggml-hexagon/htp/flash-attn-ops.c [new file with mode: 0644]
ggml/src/ggml-hexagon/htp/get-rows-ops.c [new file with mode: 0644]
ggml/src/ggml-hexagon/htp/htp-ctx.h
ggml/src/ggml-hexagon/htp/htp-msg.h
ggml/src/ggml-hexagon/htp/htp-ops.h
ggml/src/ggml-hexagon/htp/hvx-utils.c
ggml/src/ggml-hexagon/htp/hvx-utils.h
ggml/src/ggml-hexagon/htp/main.c
ggml/src/ggml-hexagon/htp/matmul-ops.c
ggml/src/ggml-hexagon/htp/set-rows-ops.c [new file with mode: 0644]
ggml/src/ggml-hexagon/htp/softmax-ops.c
ggml/src/ggml-hexagon/htp/unary-ops.c

index 13b96d61f85b7a37a44b9208564d3e596163976b..365a24b4965c87f1d078b92652003a830d008d0e 100644 (file)
@@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_
     return true;
 }
 
+static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+    const struct ggml_tensor * src2 = op->src[2];
+    const struct ggml_tensor * src3 = op->src[3];
+    const struct ggml_tensor * src4 = op->src[4];
+    const struct ggml_tensor * dst  = op;
+
+    // Check for F16 support only as requested
+    if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
+        return false;
+    }
+
+    if (src3 && src3->type != GGML_TYPE_F16) {  // mask
+        return false;
+    }
+
+    if (src4 && src4->type != GGML_TYPE_F32) {  // sinks
+        return false;
+    }
+
+    // For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
+    // but the op implementation writes to F16 or F32.
+    // Let's assume dst can be F32 or F16.
+    if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
+        return false;
+    }
+
+    return opt_experimental;
+}
+
 static bool hex_supported_src0_type(ggml_type t) {
     return t == GGML_TYPE_F32;
 }
@@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+    if (dst->type != GGML_TYPE_F32) {
         return false;
     }
 
-    // TODO: add support for non-cont tensors
-    if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+    if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
         return false;
     }
 
@@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
                 return false;  // typically the lm-head which would be too large for VTCM
             }
 
-            // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
             if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
                 return false;
             }
@@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
             }
             break;
 
-        case GGML_TYPE_F16:
-            if (!opt_experimental) {
-                return false;
-            }
-            break;
-
         default:
             return false;
     }
 
-    // TODO: add support for non-cont tensors
-    if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
-        return false;
-    }
-
     return true;
 }
 
@@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
     return true;
 }
 
+static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0]; // values
+    const struct ggml_tensor * src1 = op->src[1]; // indices
+    const struct ggml_tensor * dst  = op;
+
+    if (src0->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
+        return false;
+    }
+
+    if (dst->type != GGML_TYPE_F16) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0]; // values
+    const struct ggml_tensor * src1 = op->src[1]; // indices
+    const struct ggml_tensor * dst  = op;
+
+    if (src0->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
+        return false;
+    }
+
+    if (dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
     const int32_t * op_params = &op->op_params[0];
 
@@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t
     d->offset = (uint8_t *) t->data - buf->base;
     d->size   = ggml_nbytes(t);
 
+    if (!d->size) {
+        // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
+        d->size = 64;
+    }
+
     switch (type) {
         case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
             // Flush CPU
@@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
     return n_bufs;
 }
 
+static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_GET_ROWS;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 template <bool _is_src0_constant>
 static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     switch (t->op) {
@@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer *
     return n_bufs;
 }
 
+static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_SET_ROWS;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
 
@@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
             supported = true;
             break;
 
+        case GGML_OP_SCALE:
+            req->op   = HTP_OP_SCALE;
+            supported = true;
+            break;
+
         case GGML_OP_UNARY:
             if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
                 req->op   = HTP_OP_UNARY_SILU;
@@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs
     return n_bufs;
 }
 
+static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+    req->op = HTP_OP_FLASH_ATTN_EXT;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
     auto sess = static_cast<ggml_hexagon_session *>(backend->context);
     return sess->name.c_str();
@@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
                 break;
             case GGML_OP_RMS_NORM:
+            case GGML_OP_SCALE:
                 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
                 break;
             case GGML_OP_UNARY:
@@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
                 break;
 
+            case GGML_OP_FLASH_ATTN_EXT:
+                ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
+                break;
+
+            case GGML_OP_SET_ROWS:
+                ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
+                break;
+
+            case GGML_OP_GET_ROWS:
+                ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
+                break;
+
             default:
                 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
         }
@@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             break;
 
         case GGML_OP_RMS_NORM:
+        case GGML_OP_SCALE:
             supp = ggml_hexagon_supported_unary(sess, op);
             break;
 
@@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_rope(sess, op);
             break;
 
+        case GGML_OP_FLASH_ATTN_EXT:
+            supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
+            break;
+
+        case GGML_OP_SET_ROWS:
+            supp = ggml_hexagon_supported_set_rows(sess, op);
+            break;
+
+        case GGML_OP_GET_ROWS:
+            supp = ggml_hexagon_supported_get_rows(sess, op);
+            break;
+
         default:
             break;
     }
index 2cf8aaa42a8254b7ca20bdd769f2e43693ec2082..6a34a215fa4937aebfa6eed4340a369b7f65e3e5 100644 (file)
@@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED
     softmax-ops.c
     act-ops.c
     rope-ops.c
+    flash-attn-ops.c
+    set-rows-ops.c
+    get-rows-ops.c
 )
 
 target_compile_definitions(${HTP_LIB} PRIVATE
diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
new file mode 100644 (file)
index 0000000..04a7b84
--- /dev/null
@@ -0,0 +1,566 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+// Dot product of FP32 and FP16 vectors, accumulating to float
+static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
+    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
+    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+
+    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+    HVX_Vector       rsum = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; i++) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
+        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
+        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+        // Load x (fp16)
+        HVX_Vector x_hf  = vx[i];
+
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+    }
+
+    if (nloe) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
+        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
+        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+        // Load x (fp16)
+        HVX_Vector x_hf  = vx[i];
+
+        // Zero-out unused elements
+        // Note that we need to clear both x and y because they may contain NANs
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        x_hf = Q6_V_vand_QV(bmask, x_hf);
+        y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+    }
+
+    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
+    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+
+    hvx_vec_store_u(r, 4, rsum);
+}
+
+// Dot product of two F16 vectors, accumulating to float
+static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
+    const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+    const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
+
+    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+    HVX_Vector       rsum = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; i++) {
+        HVX_Vector y_hf = vy[i];
+        HVX_Vector x_hf = vx[i];
+
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+    }
+
+    if (nloe) {
+        HVX_Vector y_hf = vy[i];
+
+        // Load x (fp16) and zero-out unused elements
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        HVX_Vector      x_hf = Q6_V_vand_QV(bmask, vx[i]);
+
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+    }
+
+    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
+    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+    hvx_vec_store_u(r, 4, rsum);
+}
+
+// MAD: y (F32) += x (F16) * v (float)
+static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
+    const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
+    HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
+
+    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+    HVX_Vector S = hvx_vec_splat_fp16(s);
+
+    uint32_t i = 0;
+    #pragma unroll(4)
+    for (i = 0; i < nvec; ++i) {
+        // Multiply x * s -> pair of F32 vectors
+        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+        ptr_y[i*2]   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
+        ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
+    }
+
+    if (nloe) {
+        HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+
+        HVX_Vector xs = Q6_V_lo_W(xs_p);
+        i = 2 * i; // index for ptr_y
+
+        if (nloe >= 32) {
+            ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+            nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
+        }
+
+        if (nloe) {
+            HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+            hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
+        }
+    }
+}
+
+#define FLASH_ATTN_BLOCK_SIZE 128
+
+static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
+    const struct htp_tensor * q = &octx->src0;
+    const struct htp_tensor * k = &octx->src1;
+    const struct htp_tensor * v = &octx->src2;
+    const struct htp_tensor * mask  = (octx->src3.data) ? &octx->src3 : NULL;
+    const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
+    struct htp_tensor * dst = &octx->dst;
+
+    const uint32_t neq0 = q->ne[0];
+    const uint32_t neq1 = q->ne[1];
+    const uint32_t neq2 = q->ne[2];
+    const uint32_t neq3 = q->ne[3];
+
+    const uint32_t nek0 = k->ne[0];
+    const uint32_t nek1 = k->ne[1];
+    const uint32_t nek2 = k->ne[2];
+    const uint32_t nek3 = k->ne[3];
+
+    const uint32_t nev0 = v->ne[0];
+    const uint32_t nev1 = v->ne[1];
+    const uint32_t nev2 = v->ne[2];
+    const uint32_t nev3 = v->ne[3];
+
+    const uint32_t nbq1 = q->nb[1];
+    const uint32_t nbq2 = q->nb[2];
+    const uint32_t nbq3 = q->nb[3];
+
+    const uint32_t nbk1 = k->nb[1];
+    const uint32_t nbk2 = k->nb[2];
+    const uint32_t nbk3 = k->nb[3];
+
+    const uint32_t nbv1 = v->nb[1];
+    const uint32_t nbv2 = v->nb[2];
+    const uint32_t nbv3 = v->nb[3];
+
+    const uint32_t ne1 = dst->ne[1];
+    const uint32_t ne2 = dst->ne[2];
+    const uint32_t ne3 = dst->ne[3];
+
+    const uint32_t nb1 = dst->nb[1];
+    const uint32_t nb2 = dst->nb[2];
+    const uint32_t nb3 = dst->nb[3];
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) octx->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) octx->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    // total rows in q
+    const uint32_t nr = neq1*neq2*neq3;
+
+    const uint32_t dr = (nr + nth - 1) / nth;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = MIN(ir0 + dr, nr);
+
+    if (ir0 >= ir1) return;
+
+    dma_queue * dma = octx->ctx->dma[ith];
+
+    const uint32_t DK = nek0;
+    const uint32_t DV = nev0;
+
+    const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
+    const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
+
+    const size_t size_k_row = DK * sizeof(__fp16);
+    const size_t size_v_row = DV * sizeof(__fp16);
+    const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
+
+    const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
+    const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
+
+    const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+    const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+    const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+    // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
+    uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
+    uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
+    uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
+    uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
+    uint8_t * spad_a = octx->dst_spad.data  + octx->dst_spad.size_per_thread  * ith;
+
+    const uint32_t n_head = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    for (uint32_t ir = ir0; ir < ir1; ++ir) {
+        const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
+        const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
+        const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
+
+        const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
+        const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
+
+        const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
+        const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
+
+        // Fetch Q row
+        const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
+        dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
+
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
+
+        // Clear accumulator
+        float * VKQ32 = (float *) spad_a;
+        memset(VKQ32, 0, DV * sizeof(float));
+
+        const __fp16 * mp_base = NULL;
+        if (mask) {
+            const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
+            const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
+            mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
+        }
+
+        const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
+
+        // Prefetch first two blocks
+        for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
+            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+            // K
+            const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+            uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
+            dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
+
+            // V
+            const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+            uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
+            dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
+
+            // Mask
+            if (mask) {
+                const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
+                uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
+                // Mask is 1D contiguous for this row
+                dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
+            }
+        }
+
+        const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+
+        for (uint32_t ib = 0; ib < n_blocks; ++ib) {
+            const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+            const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+            // Wait for DMA
+            uint8_t * k_base = dma_queue_pop(dma).dst; // K
+            uint8_t * v_base = dma_queue_pop(dma).dst; // V
+            __fp16  * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
+
+            // Inner loop processing the block from VTCM
+            uint32_t ic = 0;
+
+            // Process in blocks of 32 (VLEN_FP32)
+            for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
+                // 1. Compute scores
+                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+                for (int j = 0; j < VLEN_FP32; ++j) {
+                    const uint32_t cur_ic = ic + j;
+                    const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
+                    if (q->type == HTP_TYPE_F32) {
+                        hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+                    } else {
+                        hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+                    }
+                }
+
+                HVX_Vector scores = *(HVX_Vector *) scores_arr;
+
+                // 2. Softcap
+                if (logit_softcap != 0.0f) {
+                    scores = hvx_vec_tanh_fp32(scores);
+                    scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
+                    scores = Q6_Vsf_equals_Vqf32(scores);
+                }
+
+                // 3. Mask
+                if (mask) {
+                    const __fp16 * mp = m_base + ic;
+                    HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
+
+                    HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
+                    HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
+
+                    HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
+
+                    HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
+                    HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
+                    scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
+                    scores = Q6_Vsf_equals_Vqf32(scores);
+                }
+
+                // 4. Online Softmax Update
+                HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
+                float m_block = hvx_vec_get_fp32(v_max);
+
+                float M_old = M;
+                float M_new = (m_block > M) ? m_block : M;
+                M = M_new;
+
+                float ms = expf(M_old - M_new);
+
+                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+                S = S * ms;
+
+                HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
+                HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
+                HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+                HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
+                float p_sum = hvx_vec_get_fp32(p_sum_vec);
+                S += p_sum;
+
+                // 5. Accumulate V
+                float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
+                *(HVX_Vector*)p_arr = P;
+
+                for (int j = 0; j < VLEN_FP32; ++j) {
+                    const uint32_t cur_ic = ic + j;
+                    const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
+                    hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+                }
+            }
+
+            // Leftover
+            for (; ic < current_block_size; ++ic) {
+                float s_val;
+                const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
+
+                if (q->type == HTP_TYPE_F32) {
+                    hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+                } else {
+                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+                }
+
+                if (logit_softcap != 0.0f) {
+                    s_val = logit_softcap * tanhf(s_val);
+                }
+
+                if (mask) {
+                    const float m_val = m_base[ic];
+                    s_val += slope * m_val;
+                }
+
+                const float Mold = M;
+                float ms = 1.0f;
+                float vs = 1.0f;
+
+                if (s_val > M) {
+                    M = s_val;
+                    ms = expf(Mold - M);
+                    hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+                } else {
+                    vs = expf(s_val - M);
+                }
+
+                const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
+
+                hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
+
+                S = S * ms + vs;
+            }
+
+            // Issue DMA for next+1 block (if exists)
+            if (ib + 2 < n_blocks) {
+                const uint32_t next_ib = ib + 2;
+                const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
+                const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
+
+                // K
+                const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+                dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
+
+                // V
+                const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+                dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
+
+                // Mask
+                if (mask) {
+                    const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
+                    dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
+                }
+            }
+        }
+
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S * ms + vs;
+        }
+
+        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+        hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
+
+        // Store result
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // dst is permuted
+        uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
+
+        if (dst->type == HTP_TYPE_F32) {
+            hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+        } else if (dst->type == HTP_TYPE_F16) {
+            hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+        }
+    }
+}
+
+static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+    flash_attn_ext_f16_thread(octx, i, n);
+}
+
+int op_flash_attn_ext(struct htp_ops_context * octx) {
+    const struct htp_tensor * q = &octx->src0;
+    const struct htp_tensor * k = &octx->src1;
+    const struct htp_tensor * v = &octx->src2;
+    const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
+    struct htp_tensor * dst = &octx->dst;
+
+    // Check support
+    if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
+        k->type != HTP_TYPE_F16 ||
+        v->type != HTP_TYPE_F16) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
+    octx->src0_div1  = init_fastdiv_values(q->ne[1]);
+
+    octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
+    octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
+    octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
+    octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
+
+    if (mask) {
+        octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
+        octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
+    }
+
+    size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
+    size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
+    size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
+
+    size_t size_q_block = size_q_row_padded * 1; // single row for now
+    size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+    size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+    size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+    size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
+
+    octx->src0_spad.size_per_thread = size_q_block * 1;
+    octx->src1_spad.size_per_thread = size_k_block * 2;
+    octx->src2_spad.size_per_thread = size_v_block * 2;
+    octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
+    octx->dst_spad.size_per_thread  = size_vkq_acc;
+
+    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+    octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
+    octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
+    octx->dst_spad.size  = octx->dst_spad.size_per_thread  * octx->n_threads;
+
+    size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
+
+    if (octx->ctx->vtcm_size < total_spad) {
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+    octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+    octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
+    octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;
+
+    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
+    }
+
+    return HTP_STATUS_OK;
+}
diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c
new file mode 100644 (file)
index 0000000..5432142
--- /dev/null
@@ -0,0 +1,112 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define get_rows_preamble \
+    const uint32_t ne00 = octx->src0.ne[0]; \
+    const uint32_t ne01 = octx->src0.ne[1]; \
+    const uint32_t ne02 = octx->src0.ne[2]; \
+    const uint32_t ne03 = octx->src0.ne[3]; \
+                                            \
+    const uint32_t ne10 = octx->src1.ne[0]; \
+    const uint32_t ne11 = octx->src1.ne[1]; \
+    const uint32_t ne12 = octx->src1.ne[2]; \
+                                            \
+    const uint32_t nb01 = octx->src0.nb[1]; \
+    const uint32_t nb02 = octx->src0.nb[2]; \
+    const uint32_t nb03 = octx->src0.nb[3]; \
+                                            \
+    const uint32_t nb10 = octx->src1.nb[0]; \
+    const uint32_t nb11 = octx->src1.nb[1]; \
+    const uint32_t nb12 = octx->src1.nb[2]; \
+                                            \
+    const uint32_t nb1 = octx->dst.nb[1];   \
+    const uint32_t nb2 = octx->dst.nb[2];   \
+    const uint32_t nb3 = octx->dst.nb[3];   \
+                                            \
+    const uint32_t nr = ne10 * ne11 * ne12;
+
+static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+    get_rows_preamble;
+
+    // parallelize by src1 elements (which correspond to dst rows)
+    const uint32_t dr  = octx->src1_nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+    for (uint32_t i = ir0; i < ir1; ++i) {
+        const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
+        const uint32_t rem = i - i12 * ne11 * ne10;
+        const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
+        const uint32_t i10 = rem - i11 * ne10;
+
+        const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+        uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+
+        if (i01 >= ne01) {
+            // invalid index, skip for now to avoid crash
+            continue;
+        }
+
+        const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
+        const uintptr_t dst_ptr  = octx->dst.data  + i10*nb1  + i11*nb2  + i12*nb3;
+        hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+    }
+
+    return HTP_STATUS_OK;
+}
+
+static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
+    get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_get_rows(struct htp_ops_context * octx) {
+    get_rows_preamble;
+
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->dst.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    octx->get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
+    octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
+
+    const uint32_t n_jobs = MIN(nr, octx->n_threads);
+    octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+
+    worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
+    return HTP_STATUS_OK;
+}
index 5c3d217f1ccf1da7179c21dd24f8d84b2d4414fc..4bd0ea7a36af3c08c0fb6adad0f95303507a0064 100644 (file)
 
 #define HTP_MAX_NTHREADS 10
 
-// FIXME: move these into matmul-ops
-#define HTP_SPAD_SRC0_NROWS 16
-#define HTP_SPAD_SRC1_NROWS 16
-#define HTP_SPAD_DST_NROWS  2
-
 // Main context for htp DSP backend
 struct htp_context {
     dspqueue_t            queue;
index a61652304aae01c6a5ce6ddabc4e1b2360d38087..846d06178438f13ec0554a9d467d310358019497 100644 (file)
@@ -36,6 +36,8 @@ enum htp_data_type {
     HTP_TYPE_F16   = 1,
     HTP_TYPE_Q4_0  = 2,
     HTP_TYPE_Q8_0  = 8,
+    HTP_TYPE_I32   = 26,
+    HTP_TYPE_I64   = 27,
     HTP_TYPE_MXFP4 = 39,
     HTP_TYPE_COUNT
 };
@@ -57,6 +59,10 @@ enum htp_op {
     HTP_OP_SOFTMAX        = 11,
     HTP_OP_ADD_ID         = 12,
     HTP_OP_ROPE           = 13,
+    HTP_OP_FLASH_ATTN_EXT = 14,
+    HTP_OP_SET_ROWS       = 15,
+    HTP_OP_SCALE          = 16,
+    HTP_OP_GET_ROWS       = 17,
     INVALID
 };
 
@@ -137,6 +143,8 @@ struct htp_general_req {
     struct htp_tensor src0;  // Input0 tensor
     struct htp_tensor src1;  // Input1 tensor
     struct htp_tensor src2;  // Input2 tensor
+    struct htp_tensor src3;  // Input3 tensor
+    struct htp_tensor src4;  // Input4 tensor
     struct htp_tensor dst;   // Output tensor
 
     // should be multiple of 64 bytes (cacheline)
@@ -152,6 +160,6 @@ struct htp_general_rsp {
 };
 
 #define HTP_MAX_MESSAGE_SIZE   sizeof(struct htp_general_req)
-#define HTP_MAX_PACKET_BUFFERS 4
+#define HTP_MAX_PACKET_BUFFERS 8
 
 #endif /* HTP_MSG_H */
index e87657436f08b9e486f0ca6260e690fc7414666e..7c828ae63620428d9773b589853cddd91d0ad9f3 100644 (file)
@@ -13,6 +13,7 @@
 
 struct htp_spad {
     uint8_t * data;
+    size_t    stride;
     size_t    size;
     size_t    size_per_thread;
 };
@@ -26,11 +27,14 @@ struct htp_ops_context {
     struct htp_tensor src0;
     struct htp_tensor src1;
     struct htp_tensor src2;
+    struct htp_tensor src3;
+    struct htp_tensor src4;
     struct htp_tensor dst;
 
     struct htp_spad src0_spad;
     struct htp_spad src1_spad;
     struct htp_spad src2_spad;
+    struct htp_spad src3_spad;
     struct htp_spad dst_spad;
 
     worker_pool_context_t * wpool;      // worker pool
@@ -49,6 +53,27 @@ struct htp_ops_context {
     struct fastdiv_values src1_div3;  // fastdiv values for ne3
     struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
 
+    struct fastdiv_values src3_div1;  // fastdiv values for ne1
+    struct fastdiv_values src3_div2;  // fastdiv values for ne2
+    struct fastdiv_values src3_div3;  // fastdiv values for ne3
+    struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
+
+    struct fastdiv_values broadcast_rk2;
+    struct fastdiv_values broadcast_rk3;
+    struct fastdiv_values broadcast_rv2;
+    struct fastdiv_values broadcast_rv3;
+
+    struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
+    struct fastdiv_values mm_div_ne1;      // fastdiv values for ne1
+    struct fastdiv_values mm_div_r2;       // fastdiv values for ne12 / ne02
+    struct fastdiv_values mm_div_r3;       // fastdiv values for ne13 / ne03
+
+    struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
+    struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
+
+    struct fastdiv_values get_rows_div_ne10;      // fastdiv values for ne10
+    struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
+
     uint32_t flags;
 };
 
@@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx);
 int op_softmax(struct htp_ops_context * octx);
 int op_add_id(struct htp_ops_context * octx);
 int op_rope(struct htp_ops_context * octx);
+int op_flash_attn_ext(struct htp_ops_context * octx);
+int op_set_rows(struct htp_ops_context * octx);
+int op_get_rows(struct htp_ops_context * octx);
 
 #endif /* HTP_OPS_H */
index f9e02ab67e53fa250f0fc6add5aec0e0a5e4d962..29d73b8622bae0734d1951562c31ca625ba4f044 100644 (file)
@@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
     return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
 }
 
-void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
-    int left_over       = num_elems & (VLEN_FP32 - 1);
-    int num_elems_whole = num_elems - left_over;
-
-    int unaligned_addr = 0;
-    int unaligned_loop = 0;
-    if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
-        FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
-        unaligned_addr = 1;
-    }
-
-    if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
-        unaligned_loop = 1;
-        FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
-    }
-
-    HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
-
-    if (0 == unaligned_loop) {
-        HVX_Vector * vec_in1 = (HVX_Vector *) src;
-        HVX_Vector * vec_out = (HVX_Vector *) dst;
-
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
-        }
-    } else {
-        #pragma unroll(4)
-        for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-
-            HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
-
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
-        }
-    }
-
-    if (left_over > 0) {
-        const float * srcf = (const float *) src + num_elems_whole;
-        float *       dstf = (float *) dst + num_elems_whole;
-
-        HVX_Vector in = *(HVX_UVector *) srcf;
-
-        HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
-    }
-}
-
 float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
     int left_over       = num_elems & (VLEN_FP32 - 1);
     int num_elems_whole = num_elems - left_over;
@@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
         hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
     }
 }
+
+
index d2d5d23636689f58949692c26d0a3a2ba34912a7..22876e6dbaadebca62c9e28406ac3d9296dae44e 100644 (file)
@@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
 }
 #endif
 
-static inline HVX_Vector hvx_vec_splat_fp32(float i) {
+static inline HVX_Vector hvx_vec_splat_fp32(float v) {
     union {
-        float   f;
-        int32_t i;
-    } fp32 = { .f = i };
+        float    f;
+        uint32_t i;
+    } fp32 = { .f = v };
 
     return Q6_V_vsplat_R(fp32.i);
 }
 
+static inline HVX_Vector hvx_vec_splat_fp16(float v) {
+    union {
+        __fp16   f;
+        uint16_t i;
+    } fp16 = { .f = v };
+
+    return Q6_Vh_vsplat_R(fp16.i);
+}
+
 static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
     // Rotate as needed.
     v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
@@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest
     }
 }
 
+// copy n fp32 elements : source is unaligned, destination unaligned
+static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_UVector * restrict vdst = (HVX_UVector *) dst;
+    HVX_UVector * restrict vsrc = (HVX_UVector *) src;
+
+    assert((unsigned long) dst % 128 == 0);
+
+    uint32_t nvec = n / 32;
+    uint32_t nloe = n % 32;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        HVX_Vector v = vsrc[i];
+        vdst[i]      = v;
+    }
+
+    if (nloe) {
+        HVX_Vector v = vsrc[i];
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
+    }
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
+static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
+    HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
+
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+    uint32_t nvec = n / 64;
+    uint32_t nloe = n % 64;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
+    }
+
+    if (nloe) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
+    }
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
+    HVX_Vector  * restrict vsrc = (HVX_Vector *)  src; // fp32
+
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+    uint32_t nvec = n / 64;
+    uint32_t nloe = n % 64;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
+    }
+
+    if (nloe) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
+    }
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    HVX_Vector  * restrict vdst = (HVX_Vector *)  dst; // fp16
+    HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
+
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+    uint32_t nvec = n / 64;
+    uint32_t nloe = n % 64;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (; i < nvec; i++) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+        vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
+    }
+
+    if (nloe) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+        HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+        HVX_Vector s_hf  = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+        hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
+    }
+}
+
 // bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
 static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
     HVX_Vector * restrict vdst = (HVX_Vector *) dst;
@@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
     return right_off <= chunk_size;
 }
 
-
-
 static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
     HVX_VectorAlias u = { .v = v };
 
@@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
 }
 
 static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
-#if __HTP_ARCH__ > 75
+#if __HVX_ARCH__ > 75
     return Q6_Vsf_vfneg_Vsf(v);
 #else
     // neg by setting the fp32 sign bit
     HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
     return Q6_V_vxor_VV(v, mask);
-#endif  // __HTP_ARCH__ > 75
+#endif  // __HVX_ARCH__ > 75
 }
 
 // ====================================================
@@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
     return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
 }
 
+static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
+    // tanh(x) = 2 * sigmoid(2x) - 1
+    HVX_Vector two = hvx_vec_splat_fp32(2.0f);
+    HVX_Vector one = hvx_vec_splat_fp32(1.0f);
+    HVX_Vector x2  = Q6_Vqf32_vmpy_VsfVsf(x, two);
+
+    static const float kMinExp = -87.f;  // 0
+    static const float kMaxExp = 87.f;   // 1
+    HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
+    HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
+
+    HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
+
+    HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
+    res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
+    return Q6_Vsf_equals_Vqf32(res);
+}
+
 static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
     int step_of_1 = num_elems >> 5;
     int remaining = num_elems - step_of_1 * VLEN_FP32;
@@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr
     }
 }
 
+static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    int nvec = n / VLEN_FP32;
+    int nloe = n % VLEN_FP32;
+
+    HVX_Vector vs = hvx_vec_splat_fp32(scale);
+
+    HVX_Vector * vsrc = (HVX_Vector *) src;
+    HVX_Vector * vdst = (HVX_Vector *) dst;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; ++i) {
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+        vdst[i]      = Q6_Vsf_equals_Vqf32(v);
+    }
+
+    if (nloe) {
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    int nvec = n / VLEN_FP32;
+    int nloe = n % VLEN_FP32;
+
+    HVX_Vector vs = hvx_vec_splat_fp32(scale);
+
+    HVX_UVector * vsrc = (HVX_UVector *) src;
+    HVX_UVector * vdst = (HVX_UVector *) dst;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; ++i) {
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+        vdst[i]      = Q6_Vsf_equals_Vqf32(v);
+    }
+
+    if (nloe) {
+        HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+    if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
+        hvx_scale_f32_aa(dst, src, n, scale);
+    } else {
+        hvx_scale_f32_uu(dst, src, n, scale);
+    }
+}
+
+static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    int nvec = n / VLEN_FP32;
+    int nloe = n % VLEN_FP32;
+
+    HVX_Vector vs = hvx_vec_splat_fp32(scale);
+    HVX_Vector vo = hvx_vec_splat_fp32(offset);
+
+    HVX_Vector * vsrc = (HVX_Vector *) src;
+    HVX_Vector * vdst = (HVX_Vector *) dst;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; ++i) {
+        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+        vdst[i] = Q6_Vsf_equals_Vqf32(v);
+    }
+
+    if (nloe) {
+        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    int nvec = n / VLEN_FP32;
+    int nloe = n % VLEN_FP32;
+
+    HVX_Vector vs = hvx_vec_splat_fp32(scale);
+    HVX_Vector vo = hvx_vec_splat_fp32(offset);
+
+    HVX_UVector * vsrc = (HVX_UVector *) src;
+    HVX_UVector * vdst = (HVX_UVector *) dst;
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; ++i) {
+        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+        vdst[i] = Q6_Vsf_equals_Vqf32(v);
+    }
+
+    if (nloe) {
+        HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+        hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+    }
+}
+
+static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+    if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
+        hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
+    } else {
+        hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
+    }
+}
 
 float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
 void  hvx_mul_f32(const uint8_t * restrict src0,
@@ -1090,7 +1338,6 @@ void  hvx_sub_f32_opt(const uint8_t * restrict src0,
                       uint8_t * restrict dst,
                       const int num_elems);
 void  hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
-void  hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
 void  hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
 void  hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
 void  hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
index fb5508a560f30b9693d01f9ce37f16cda8aaf43e..24b3e90e4b6903a579adf6753144cbb40edf533b 100644 (file)
@@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context *     ctx,
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[2].fd;
+    rsp_bufs[0].ptr    = bufs[2].ptr;
+    rsp_bufs[0].offset = bufs[2].offset;
+    rsp_bufs[0].size   = bufs[2].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.dst.data  = (uint32_t) bufs[2].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_get_rows(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_matmul_id_req(struct htp_context *     ctx,
                                struct htp_general_req * req,
                                struct dspqueue_buffer * bufs,
@@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context *     ctx,
                           uint32_t                 n_bufs) {
     struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
 
-    int write_idx = (n_bufs == 4) ? 3 : 2;
+    int write_idx = n_bufs - 1;
 
     // We had written to the output buffer, we'd also need to flush it
     rsp_bufs[0].fd     = bufs[write_idx].fd;
@@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context *     ctx,
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[2].fd;
+    rsp_bufs[0].ptr    = bufs[2].ptr;
+    rsp_bufs[0].offset = bufs[2].offset;
+    rsp_bufs[0].size   = bufs[2].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.src1                   = req->src1;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.dst.data  = (uint32_t) bufs[2].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_set_rows(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_flash_attn_ext_req(struct htp_context *     ctx,
+                                    struct htp_general_req * req,
+                                    struct dspqueue_buffer * bufs,
+                                    uint32_t                 n_bufs) {
+    // Setup Op context
+    struct htp_ops_context octx;
+    memset(&octx, 0, sizeof(octx));
+
+    octx.ctx   = ctx;
+    octx.n_threads = ctx->n_threads;
+
+    octx.src0  = req->src0;
+    octx.src1  = req->src1;
+    octx.src2  = req->src2;
+    octx.src3  = req->src3;
+    octx.src4  = req->src4;
+    octx.dst   = req->dst;
+    octx.flags = req->flags;
+    octx.op    = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.src1.data = (uint32_t) bufs[1].ptr;
+    octx.src2.data = (uint32_t) bufs[2].ptr;
+
+    int last_buf = 3;
+
+    if (octx.src3.ne[0]) {
+        octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
+    }
+
+    if (octx.src4.ne[0]) {
+        octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
+    }
+
+    octx.dst.data = (uint32_t) bufs[last_buf].ptr;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_flash_attn_ext(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+
+    struct dspqueue_buffer rsp_buf = bufs[last_buf];
+    rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                     DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof);
+}
+
 static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
     struct htp_context * ctx = (struct htp_context *) context;
 
@@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 break;
 
             case HTP_OP_RMS_NORM:
+            case HTP_OP_SCALE:
                 if (n_bufs != 2) {
                     FARF(ERROR, "Bad unary-req buffer list");
                     continue;
@@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_rope_req(ctx, &req, bufs, n_bufs);
                 break;
 
+            case HTP_OP_FLASH_ATTN_EXT:
+                if (!(n_bufs >= 4 && n_bufs <= 6)) {
+                    FARF(ERROR, "Bad flash-attn-ext-req buffer list");
+                    continue;
+                }
+                proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
+                break;
+
+            case HTP_OP_SET_ROWS:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad set-rows-req buffer list");
+                    continue;
+                }
+                proc_set_rows_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_GET_ROWS:
+                if (n_bufs != 3) {
+                    FARF(ERROR, "Bad get-rows-req buffer list");
+                    continue;
+                }
+                proc_get_rows_req(ctx, &req, bufs);
+                break;
+
             default:
                 FARF(ERROR, "Unknown Op %u", req.op);
                 break;
index f14523d485cf43ce82e2feaa63ca183a584cf65b..9bb39db9fcbd9dabd644678a40d101adb86efd18 100644 (file)
 #include "hvx-utils.h"
 #include "ops-utils.h"
 
+#define MM_SPAD_SRC0_NROWS 16
+#define MM_SPAD_SRC1_NROWS 16
+#define MM_SPAD_DST_NROWS  2
+
 struct htp_matmul_type {
     const char * type;
     void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-    void (*vec_dot_rx2)(const int n,
-                        float * restrict s,
-                        const void * restrict vx,
-                        uint32_t vx_row_size,
-                        const void * restrict vy);
+    void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
 };
 
 typedef struct {
@@ -907,145 +907,174 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
     hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
 }
 
-#if 1
-static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
-    if (0) {
-        float rsum                 = 0;
-        const __fp16 * restrict vx = (const __fp16 * restrict) x;
-        const float * restrict vy  = (const float * restrict) y;
+static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const HVX_Vector * restrict x = (const HVX_Vector *) vx;
+    const HVX_Vector * restrict y = (const HVX_Vector *) vy;
 
-        for (uint32_t i = 0; i < n; i++) {
-            rsum += (float)vx[i] * vy[i];
-        }
-        *s = rsum;
-        return;
+    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+    HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; i++) {
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    const HVX_UVector * restrict vx     = (const HVX_UVector * restrict) x;
-    const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y;
+    if (nloe) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+        HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
 
-    uint32_t nv0 = n / 64;  // num full fp16 hvx vectors
-    uint32_t nv1 = n % 64;  // leftover elements
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+    }
+
+    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+    hvx_vec_store_u(&s[0], 4, rsum);
+}
+
+static void vec_dot_f16_f16_aa_rx2(const int n,
+                                float * restrict s,
+                                const void * restrict vx,
+                                uint32_t vx_row_size,
+                                const void * restrict vy) {
+    const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
+    const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
+    const HVX_Vector * restrict y  = (const HVX_Vector *) vy;
+
+    uint32_t nvec = n / VLEN_FP16;
+    uint32_t nloe = n % VLEN_FP16;
+
+    HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+    HVX_Vector rsum1 = Q6_V_vsplat_R(0);
 
-    // for some reason we need volatile here so that the compiler doesn't try anything funky
-    volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
-    float r_sum_scalar = 0.0f;
     uint32_t i = 0;
 
-    for (i = 0; i < nv0; i++) {
-        HVX_VectorPair yp = vy[i];
+    #pragma unroll(2)
+    for (i = 0; i < nvec; i++) {
+        HVX_Vector y_hf = y[i];
+        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
+        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
 
-        HVX_Vector     x  = vx[i];
-        HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00));  // mul by 1.0
+        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+    }
 
-        //NOTE: need volatile here to prevent compiler optimization
-        // Seem compiler cannot guarantee read-after-write??
-        volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
-        volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+    if (nloe) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
+        HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
+        HVX_Vector y_hf  = Q6_V_vand_QV(bmask, y[i]);
+
+        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
 
-        HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
-        rsum           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
     }
 
-    if (nv1) {
-        // HVX_VectorPair yp = vy[i];
+    rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0));
+    rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1));
+    HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
 
-        // HVX_Vector     x  = vx[i];
-        // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00));  // mul by 1.0
+    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+}
 
-        // if (nv1 >= 32) {
-        //     volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
-        //     rsum          = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
-        //     nv1 -= 32;
-        // }
+static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const HVX_UVector * restrict x = (const HVX_UVector *) vx;
+    const HVX_UVector * restrict y = (const HVX_UVector *) vy;
 
-        // rsum = hvx_vec_qf32_reduce_sum(rsum);
+    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16; // leftover elements
 
-        // if (nv1) {
-        //     volatile HVX_Vector lo  = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
-        //     HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
-        //     rsum           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
-        // }
+    HVX_Vector rsum = Q6_V_vsplat_R(0);
 
-        //process the remainder using scalar loop
-        rsum = hvx_vec_qf32_reduce_sum(rsum);
-        const __fp16 * restrict sx = (const __fp16 * restrict) x;
-        const float * restrict sy  = (const float * restrict) y;
+    uint32_t i = 0;
 
-        for (uint32_t i = nv0 * 64; i < n; i++) {
-            r_sum_scalar += (float) sx[i] * sy[i];
-        }
+    #pragma unroll(4)
+    for (i = 0; i < nvec; i++) {
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+    }
+
+    if (nloe) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+        HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
 
-        // hvx_vec_dump_fp16("X", x);
-        // hvx_vec_dump_fp16("Y", y);
-        // hvx_vec_dump_fp32("SUM",  Q6_Vsf_equals_Vqf32(sum));
-        // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum));
-    } else {
-        rsum = hvx_vec_qf32_reduce_sum(rsum);
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
+    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+    hvx_vec_store_u(&s[0], 4, rsum);
+}
 
-#    ifdef HTP_DEBUG
-    {
-        float rsum                 = 0;
-        const __fp16 * restrict vx = (const __fp16 * restrict) x;
-        const float * restrict vy  = (const float * restrict) y;
+static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+    const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
+    const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
 
-        for (uint32_t i = 0; i < n; i++) {
-            rsum += vx[i] * vy[i];
-        }
+    uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16; // leftover elements
 
-        float diff = fabs(*s - rsum);
-        if (diff > 0.001) {
-            FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s);
-            // htp_dump_f16("x", vx, n);
-            // htp_dump_f32("y", vy, n);
-        }
+    const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+    HVX_Vector       rsum = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
+
+    #pragma unroll(2)
+    for (i = 0; i < nvec; i++) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
+        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
+        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+        // Load x (fp16)
+        HVX_Vector x_hf  = vx[i];
+
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
-#    endif
-}
-#else
-static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
-    const uint32_t fk = 64;
-    const uint32_t nb = n / fk;
 
-    assert(n % fk == 0);
-    assert(nb % 4 == 0);
+    if (nloe) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
+        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
+        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
 
-    const uint32_t x_blk_size = 2 * fk;  // fp16
-    const uint32_t y_blk_size = 4 * fk;  // fp32
+        // Load x (fp16)
+        HVX_Vector x_hf  = vx[i];
 
-    // Row sum (qf32)
-    HVX_Vector rsum0 = Q6_V_vsplat_R(0);
-    HVX_Vector rsum1 = Q6_V_vsplat_R(0);
-    HVX_Vector rsum2 = Q6_V_vsplat_R(0);
-    HVX_Vector rsum3 = Q6_V_vsplat_R(0);
-
-    for (uint32_t i = 0; i < nb; i += 4) {
-        HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size));
-        HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size));
-
-        HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]);
-        HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]);
-        HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]);
-        HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]);
-
-        rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0)));
-        rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1)));
-        rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2)));
-        rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3)));
+        // Zero-out unused elements
+        // Note that we need to clear both x and y because they may contain NANs
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        x_hf = Q6_V_vand_QV(bmask, x_hf);
+        y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+        HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    // Reduce and convert into fp32
-    rsum0           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1);
-    rsum2           = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3);
-    HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2));
-    hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum));
+    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+    hvx_vec_store_u(&s[0], 4, rsum);
 }
-#endif
 
-#define htp_matmul_preamble            \
+#define htp_matmul_tensors_preamble    \
+    struct htp_tensor * restrict src0    = &octx->src0;      \
+    struct htp_tensor * restrict src1    = &octx->src1;      \
+    struct htp_tensor * restrict src2    = &octx->src2;      \
+    struct htp_tensor * restrict dst     = &octx->dst;       \
+    struct htp_spad * restrict src0_spad = &octx->src0_spad; \
+    struct htp_spad * restrict src1_spad = &octx->src1_spad; \
+    struct htp_spad * restrict dst_spad  = &octx->dst_spad;  \
+                                                             \
     const uint32_t ne00 = src0->ne[0]; \
     const uint32_t ne01 = src0->ne[1]; \
     const uint32_t ne02 = src0->ne[2]; \
@@ -1056,6 +1085,11 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
     const uint32_t ne12 = src1->ne[2]; \
     const uint32_t ne13 = src1->ne[3]; \
                                        \
+    const uint32_t ne20 = src2->ne[0]; \
+    const uint32_t ne21 = src2->ne[1]; \
+    const uint32_t ne22 = src2->ne[2]; \
+    const uint32_t ne23 = src2->ne[3]; \
+                                       \
     const uint32_t ne0 = dst->ne[0];   \
     const uint32_t ne1 = dst->ne[1];   \
     const uint32_t ne2 = dst->ne[2];   \
@@ -1076,18 +1110,94 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-// q8x4 src1 tensor is already in VTCM spad
-static void matmul(struct htp_matmul_type * mt,
-                   struct htp_tensor * restrict src0,
-                   struct htp_tensor * restrict src1,
-                   struct htp_tensor * restrict dst,
-                   struct htp_spad * restrict src0_spad,
-                   struct htp_spad * restrict src1_spad,
-                   struct htp_spad * restrict dst_spad,
-                   uint32_t    nth,
-                   uint32_t    ith,
-                   uint32_t    src0_nrows_per_thread,
-                   dma_queue * dma_queue) {
+#define htp_matmul_preamble            \
+    htp_matmul_tensors_preamble;       \
+    dma_queue *dma_queue           = octx->ctx->dma[ith];         \
+    uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+// *** matmul with support for 4d tensors and full broadcasting
+
+static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+    htp_matmul_preamble;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
+
+    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
+    const uint32_t nr0 = ne0;
+
+    // This is the size of the rest of the dimensions of the result
+    const uint32_t nr1 = ne1 * ne2 * ne3;
+
+    // distribute the thread work across the inner or outer loop based on which one is larger
+    uint32_t nchunk0 = nr0 > nr1 ? nth : 1;  // parallelize by src0 rows
+    uint32_t nchunk1 = nr0 > nr1 ? 1 : nth;  // parallelize by src1 rows
+
+    // The number of elements in each chunk
+    const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
+    uint32_t current_chunk = ith;
+
+    const uint32_t ith0 = current_chunk % nchunk0;
+    const uint32_t ith1 = current_chunk / nchunk0;
+
+    const uint32_t ir0_start = dr0 * ith0;
+    const uint32_t ir0_end   = MIN(ir0_start + dr0, nr0);
+
+    const uint32_t ir1_start = dr1 * ith1;
+    const uint32_t ir1_end   = MIN(ir1_start + dr1, nr1);
+
+    // no work for this thread
+    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+        return;
+    }
+
+    // block-tiling attempt
+    const uint32_t blck_0 = 64;
+    const uint32_t blck_1 = 64;
+
+    for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
+                const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
+                const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
+                const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
+                const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);
+
+                const uint32_t i1 = i11;
+                const uint32_t i2 = i12;
+                const uint32_t i3 = i13;
+
+                const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+                const uint8_t * restrict src1_col  = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
+                float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
+                for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
+                    const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+                    mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
+                }
+            }
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
+         src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// src1 tensor is already in VTCM spad
+static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
     htp_matmul_preamble;
 
     const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
@@ -1104,9 +1214,10 @@ static void matmul(struct htp_matmul_type * mt,
 
     const size_t dst_row_size  = nb1;
     const size_t src0_row_size = nb01;
-    const size_t src1_row_size = q8x4x2_row_size(ne10);
+    const size_t src1_row_size = nb11;
 
-    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+    const size_t src0_stride = src0_spad->stride;
+    const size_t src1_stride = src1_spad->stride;
 
     // Per-thread VTCM scratchpads for all tensors
     // Note that the entire src1 tensor is already in VTCM
@@ -1124,11 +1235,11 @@ static void matmul(struct htp_matmul_type * mt,
     #pragma unroll(4)
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const int is0 = (ir0 - src0_start_row);
-        if (is0 >= HTP_SPAD_SRC0_NROWS) {
+        if (is0 >= MM_SPAD_SRC0_NROWS) {
             break;
         }
-        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
-                       src0_row_size_padded, src0_row_size, 2);
+        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+                       src0_stride, src0_row_size, 2);
     }
 
     // Process src0 rows
@@ -1137,17 +1248,17 @@ static void matmul(struct htp_matmul_type * mt,
 
         #pragma unroll(2)
         for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
-            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
             float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
-            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+            mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
         }
 
         // Prefetch next (n + spad_nrows) row
-        const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
-        const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+        const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+        const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
         if (pr0 < src0_end_row_x2) {
-            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
-                           src0_row_size_padded, src0_row_size, 2);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+                           src0_stride, src0_row_size, 2);
         }
     }
 
@@ -1155,13 +1266,13 @@ static void matmul(struct htp_matmul_type * mt,
     if (src0_end_row != src0_end_row_x2) {
         uint32_t  ir0 = src0_end_row_x2;
         const int is0 = (ir0 - src0_start_row);
-        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
-                       src0_row_size_padded, src0_row_size, 1);
+        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+                       src0_stride, src0_row_size, 1);
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
 
         #pragma unroll(2)
         for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
-            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+            const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
             float * restrict dst_row          = (float *) (dst->data + (ir1 * dst_row_size));
             mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
         }
@@ -1176,17 +1287,7 @@ static void matmul(struct htp_matmul_type * mt,
 }
 
 // q8x4x2 src1 tensor is already in VTCM spad
-static void matvec(struct htp_matmul_type * mt,
-                   struct htp_tensor * restrict src0,
-                   struct htp_tensor * restrict src1,
-                   struct htp_tensor * restrict dst,
-                   struct htp_spad * restrict src0_spad,
-                   struct htp_spad * restrict src1_spad,
-                   struct htp_spad * restrict dst_spad,
-                   uint32_t    nth,
-                   uint32_t    ith,
-                   uint32_t    src0_nrows_per_thread,
-                   dma_queue * dma_queue) {
+static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
     htp_matmul_preamble;
 
     const uint32_t src0_nrows = ne01;
@@ -1202,9 +1303,10 @@ static void matvec(struct htp_matmul_type * mt,
 
     const size_t dst_row_size  = nb1;
     const size_t src0_row_size = nb01;
-    const size_t src1_row_size = q8x4x2_row_size(ne10);
+    const size_t src1_row_size = nb11;
 
-    const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+    const size_t src0_stride = src0_spad->stride;
+    const size_t src1_stride = src1_spad->stride;
 
     // Per-thread VTCM scratchpads for all tensors
     // Note that the entire src1 tensor is already in VTCM
@@ -1226,24 +1328,24 @@ static void matvec(struct htp_matmul_type * mt,
     #pragma unroll(2)
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const uint32_t is0 = (ir0 - src0_start_row);
-        if (is0 >= HTP_SPAD_SRC0_NROWS) {
+        if (is0 >= MM_SPAD_SRC0_NROWS) {
             break;
         }
-        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
-                       src0_row_size_padded, src0_row_size, 2);
+        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+                       src0_stride, src0_row_size, 2);
     }
 
     // Process src0 rows
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
-        mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col);
+        mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);
 
         // Prefetch next (n + spad_nrows) row
-        const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
-        const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+        const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+        const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
         if (pr0 < src0_end_row_x2) {
-            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
-                           src0_row_size_padded, src0_row_size, 2);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+                           src0_stride, src0_row_size, 2);
         }
     }
 
@@ -1251,8 +1353,8 @@ static void matvec(struct htp_matmul_type * mt,
     if (src0_end_row != src0_end_row_x2) {
         const uint32_t ir0 = src0_end_row_x2;
         const uint32_t is0 = (ir0 - src0_start_row);
-        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
-                       src0_row_size_padded, src0_row_size, 1);
+        dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+                       src0_stride, src0_row_size, 1);
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
         mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
     }
@@ -1274,22 +1376,13 @@ struct mmid_row_mapping {
     uint32_t i2;
 };
 
-// q8x4 src1 tensor is already in VTCM spad
-static void matmul_id(struct htp_matmul_type * mt,
-                      struct htp_tensor * restrict src0,
-                      struct htp_tensor * restrict src1,
-                      struct htp_tensor * restrict ids,
-                      struct htp_tensor * restrict dst,
-                      struct htp_spad * restrict src0_spad,
-                      struct htp_spad * restrict src1_spad,
-                      struct htp_spad * restrict src2_spad,
-                      struct htp_spad * restrict dst_spad,
-                      uint32_t    nth,
-                      uint32_t    ith,
-                      uint32_t    src0_nrows_per_thread,
-                      dma_queue * dma_queue) {
+// src1 tensor is already in VTCM spad
+static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
     htp_matmul_preamble;
 
+    struct htp_tensor * restrict     ids = &octx->src2;
+    struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
@@ -1340,7 +1433,7 @@ static void matmul_id(struct htp_matmul_type * mt,
         #pragma unroll(4)
         for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
             const int is0 = (ir0 - src0_start_row);
-            if (is0 >= HTP_SPAD_SRC0_NROWS) {
+            if (is0 >= MM_SPAD_SRC0_NROWS) {
                 break;
             }
             dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
@@ -1365,8 +1458,8 @@ static void matmul_id(struct htp_matmul_type * mt,
             }
 
             // Prefetch next (n + spad_nrows) row
-            const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
-            const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+            const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
             if (pr0 < src0_end_row_x2) {
                 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
                                src0_row_size_padded, src0_row_size, 2);
@@ -1404,22 +1497,13 @@ static void matmul_id(struct htp_matmul_type * mt,
          dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-// q8x4 src1 tensor is already in VTCM spad
-static void matvec_id(struct htp_matmul_type * mt,
-                      struct htp_tensor * restrict src0,
-                      struct htp_tensor * restrict src1,
-                      struct htp_tensor * restrict src2,
-                      struct htp_tensor * restrict dst,
-                      struct htp_spad * restrict src0_spad,
-                      struct htp_spad * restrict src1_spad,
-                      struct htp_spad * restrict src2_spad,
-                      struct htp_spad * restrict dst_spad,
-                      uint32_t    nth,
-                      uint32_t    ith,
-                      uint32_t    src0_nrows_per_thread,
-                      dma_queue * dma_queue) {
+// src1 tensor is already in VTCM spad
+static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
     htp_matmul_preamble;
 
+    struct htp_tensor * restrict     ids = &octx->src2;
+    struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
     uint64_t t1, t2;
     t1 = HAP_perf_get_qtimer_count();
 
@@ -1464,7 +1548,7 @@ static void matvec_id(struct htp_matmul_type * mt,
         #pragma unroll(4)
         for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
             const int is0 = (ir0 - src0_start_row);
-            if (is0 >= HTP_SPAD_SRC0_NROWS) {
+            if (is0 >= MM_SPAD_SRC0_NROWS) {
                 break;
             }
             dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
@@ -1477,8 +1561,8 @@ static void matvec_id(struct htp_matmul_type * mt,
             mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
 
             // Prefetch next (n + spad_nrows) row
-            const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
-            const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+            const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+            const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
             if (pr0 < src0_end_row_x2) {
                 dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
                                src0_row_size_padded, src0_row_size, 2);
@@ -1504,106 +1588,6 @@ static void matvec_id(struct htp_matmul_type * mt,
          dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
-// *** matmul in fp16
-
-static void matmul_f16_f32(struct htp_tensor * restrict src0,
-                           struct htp_tensor * restrict src1,
-                           struct htp_tensor * restrict dst,
-                           struct htp_spad * restrict src0_spad,
-                           struct htp_spad * restrict src1_spad,
-                           struct htp_spad * restrict dst_spad,
-                           uint32_t    nth,
-                           uint32_t    ith,
-                           uint32_t    src0_nrows_per_thread,
-                           dma_queue * dma_queue) {
-    htp_matmul_preamble;
-
-    uint64_t t1, t2;
-    t1 = HAP_perf_get_qtimer_count();
-
-    assert(ne12 % ne02 == 0);
-    assert(ne13 % ne03 == 0);
-
-    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
-    const uint32_t nr0 = ne0;
-
-    // This is the size of the rest of the dimensions of the result
-    const uint32_t nr1 = ne1 * ne2 * ne3;
-
-    // distribute the thread work across the inner or outer loop based on which one is larger
-    uint32_t nchunk0 = nr0 > nr1 ? nth : 1;  // parallelize by src0 rows
-    uint32_t nchunk1 = nr0 > nr1 ? 1 : nth;  // parallelize by src1 rows
-
-    // The number of elements in each chunk
-    const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
-    const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
-    uint32_t current_chunk = ith;
-
-    const uint32_t ith0 = current_chunk % nchunk0;
-    const uint32_t ith1 = current_chunk / nchunk0;
-
-    const uint32_t ir0_start = dr0 * ith0;
-    const uint32_t ir0_end   = MIN(ir0_start + dr0, nr0);
-
-    const uint32_t ir1_start = dr1 * ith1;
-    const uint32_t ir1_end   = MIN(ir1_start + dr1, nr1);
-
-    // broadcast factors
-    const uint32_t r2 = ne12 / ne02;
-    const uint32_t r3 = ne13 / ne03;
-
-    // no work for this thread
-    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
-        return;
-    }
-
-    // block-tiling attempt
-    const uint32_t blck_0 = 64;
-    const uint32_t blck_1 = 64;
-
-    __attribute__((aligned(128))) float tmp[64];
-
-    for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
-        for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
-            for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
-                const uint32_t i13 = (ir1 / (ne12 * ne1));
-                const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
-                const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
-
-                // broadcast src0 into src1
-                const uint32_t i03 = i13 / r3;
-                const uint32_t i02 = i12 / r2;
-
-                const uint32_t i1 = i11;
-                const uint32_t i2 = i12;
-                const uint32_t i3 = i13;
-
-                const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
-                const uint8_t * restrict src1_col =
-                    (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
-                float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
-
-                const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
-                for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
-                    // Use nb01 stride for non-contiguous src0 support
-                    const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
-                    vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col);
-                }
-
-                hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
-            }
-        }
-    }
-
-    t2 = HAP_perf_get_qtimer_count();
-
-    FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
-         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
-         src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
-         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
-}
-
 // *** dynamic quant
 
 static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
@@ -1780,20 +1764,14 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
 
     for (uint32_t i = 0; i < nb; i++) {
 #if FP32_QUANTIZE_GROUP_SIZE == 32
-        quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
-                                 t_d + (i * 2 + 0) * dblk_size / 2);
-        quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
-                                 t_d + (i * 2 + 1) * dblk_size / 2);
+        quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+        quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
 #elif FP32_QUANTIZE_GROUP_SIZE == 64
-        quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
-                                 t_d + (i * 2 + 0) * dblk_size / 2);
-        quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
-                                 t_d + (i * 2 + 1) * dblk_size / 2);
+        quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+        quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
 #elif FP32_QUANTIZE_GROUP_SIZE == 128
-        quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
-                                 t_d + (i * 2 + 0) * dblk_size / 2);
-        quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
-                                 t_d + (i * 2 + 1) * dblk_size / 2);
+        quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+        quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
 #else
 #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
 #endif
@@ -1848,14 +1826,95 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
          ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
+static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
+                              uint32_t nrows_per_thread, uint32_t dst_stride) {
+
+    uint64_t t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t ne0 = src->ne[0];
+    const uint32_t ne1 = src->ne[1];
+    const uint32_t ne2 = src->ne[2];
+    const uint32_t ne3 = src->ne[3];
+
+    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows
+
+    const uint32_t ir_first = nrows_per_thread * ith;                   // first row
+    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row
+
+    const size_t src_row_size = ne0 * sizeof(float);
+    const size_t src_stride   = src->nb[1];
+
+    uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+    uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);
+
+    for (uint32_t i = ir_first; i < ir_last; ++i) {
+        htp_l2fetch(src_data, 2, src_row_size, src_stride);
+        hvx_copy_fp16_fp32_au(dst_data, src_data, ne0);
+
+        dst_data += dst_stride;
+        src_data += src_stride;
+    }
+
+    uint64_t t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+        ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// TODO just a plain copy that should be done via the DMA during the Op setup
+static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
+                              uint32_t nrows_per_thread, uint32_t dst_stride) {
+
+    uint64_t t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t ne0 = src->ne[0];
+    const uint32_t ne1 = src->ne[1];
+    const uint32_t ne2 = src->ne[2];
+    const uint32_t ne3 = src->ne[3];
+
+    const uint32_t nrows = ne1 * ne2 * ne3;                             // total n_rows
+
+    const uint32_t ir_first = nrows_per_thread * ith;                   // first row
+    const uint32_t ir_last  = MIN(ir_first + nrows_per_thread, nrows);  // last row
+
+    const size_t src_row_size = ne0 * sizeof(float);
+    const size_t src_stride   = src->nb[1];
+
+    uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+    uint8_t * restrict dst_data = (uint8_t *) dst       + (dst_stride * ir_first);
+
+    for (uint32_t i = ir_first; i < ir_last; ++i) {
+        htp_l2fetch(src_data, 2, src_row_size, src_stride);
+        hvx_copy_fp16_au(dst_data, src_data, ne0);
+
+        dst_data += dst_stride;
+        src_data += src_stride;
+    }
+
+    uint64_t t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+        ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
 static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
     quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
 }
 
-// ** matmul callbacks for worker_pool
+static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+    quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
+}
+
+static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+    quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
+}
+
+// ** matmul/matvec callbacks for worker_pool
 
-static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
 
     struct htp_matmul_type mt;
@@ -1863,11 +1922,10 @@ static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data
     mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
 
-    matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
-           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matvec_2d(&mt, octx, n, i);
 }
 
-static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
 
     struct htp_matmul_type mt;
@@ -1875,11 +1933,10 @@ static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data
     mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
 
-    matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
-           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matmul_2d(&mt, octx, n, i);
 }
 
-static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
 
     struct htp_matmul_type mt;
@@ -1887,11 +1944,10 @@ static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data
     mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
 
-    matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
-           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matvec_2d(&mt, octx, n, i);
 }
 
-static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
 
     struct htp_matmul_type mt;
@@ -1899,11 +1955,10 @@ static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data
     mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
 
-    matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
-           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matmul_2d(&mt, octx, n, i);
 }
 
-static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
 
     struct htp_matmul_type mt;
@@ -1911,11 +1966,10 @@ static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
     mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
 
-    matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
-           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matvec_2d(&mt, octx, n, i);
 }
 
-static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
 
     struct htp_matmul_type mt;
@@ -1923,14 +1977,49 @@ static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
     mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
 
-    matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
-           octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matmul_2d(&mt, octx, n, i);
 }
 
-static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = data;
-    matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
-                   octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+
+    struct htp_matmul_type mt;
+    mt.type        = "f16-f16";
+    mt.vec_dot     = vec_dot_f16_f16_aa;
+    mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
+
+    matvec_2d(&mt, octx, n, i);
+}
+
+static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "f16-f16";
+    mt.vec_dot     = vec_dot_f16_f16_aa;
+    mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
+
+    matmul_2d(&mt, octx, n, i);
+}
+
+static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "f16-f32";
+    mt.vec_dot     = vec_dot_f16_f32_uu;
+
+    matmul_4d(&mt, octx, n, i);
+}
+
+static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = data;
+
+    struct htp_matmul_type mt;
+    mt.type        = "f16-f16";
+    mt.vec_dot     = vec_dot_f16_f16_uu;
+
+    matmul_4d(&mt, octx, n, i);
 }
 
 // ** matmul-id callbacks for worker_pool
@@ -1943,8 +2032,7 @@ static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
     mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
 
-    matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
-              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matvec_id(&mt, octx, n, i);
 }
 
 static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
@@ -1955,8 +2043,7 @@ static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
     mt.vec_dot     = vec_dot_q4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
 
-    matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
-              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matmul_id(&mt, octx, n, i);
 }
 
 static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
@@ -1967,8 +2054,7 @@ static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
     mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
 
-    matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
-              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matvec_id(&mt, octx, n, i);
 }
 
 static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
@@ -1979,8 +2065,7 @@ static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
     mt.vec_dot     = vec_dot_q8x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
 
-    matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
-              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matmul_id(&mt, octx, n, i);
 }
 
 static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
@@ -1991,8 +2076,7 @@ static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void
     mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
 
-    matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
-              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matvec_id(&mt, octx, n, i);
 }
 
 static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
@@ -2003,18 +2087,17 @@ static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void
     mt.vec_dot     = vec_dot_mxfp4x4x2_q8x4x2;
     mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
 
-    matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
-              &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+    matmul_id(&mt, octx, n, i);
 }
 
 // ** main matmul entry point
 
-int op_matmul(struct htp_ops_context * octx) {
-    const struct htp_tensor * src0 = &octx->src0;
-    const struct htp_tensor * src1 = &octx->src1;
-    struct htp_tensor *       dst  = &octx->dst;
+static inline bool htp_is_permuted(const struct htp_tensor * t) {
+    return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
+}
 
-    htp_matmul_preamble;
+int op_matmul(struct htp_ops_context * octx) {
+    htp_matmul_tensors_preamble;
 
     const char * op_type;
 
@@ -2038,9 +2121,9 @@ int op_matmul(struct htp_ops_context * octx) {
             op_type        = "q4x4x2-fp32";
             quant_job_func = htp_quantize_fp32_q8x4x2;
             if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_q4x4x2_q8x4x2;
+                matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
             } else {
-                matmul_job_func = htp_matvec_q4x4x2_q8x4x2;
+                matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
             }
 
             src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
@@ -2048,8 +2131,8 @@ int op_matmul(struct htp_ops_context * octx) {
             // Entire src1 tensor is placed into the VTCM
             // For other tensors we allocate N rows per thread, padded to HVX vector size
 
-            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
             octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
 
             // src0 spad is also used in dynamic quantizer to store padded src1 rows
@@ -2067,9 +2150,9 @@ int op_matmul(struct htp_ops_context * octx) {
             op_type        = "q8x4x2-fp32";
             quant_job_func = htp_quantize_fp32_q8x4x2;
             if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_q8x4x2_q8x4x2;
+                matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
             } else {
-                matmul_job_func = htp_matvec_q8x4x2_q8x4x2;
+                matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
             }
 
             src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
@@ -2077,8 +2160,8 @@ int op_matmul(struct htp_ops_context * octx) {
             // Entire src1 tensor is placed into the VTCM
             // For other tensors we allocate N rows per thread, padded to HVX vector size
 
-            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
             octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
 
             // src0 spad is also used in dynamic quantizer to store padded src1 rows
@@ -2096,9 +2179,9 @@ int op_matmul(struct htp_ops_context * octx) {
             op_type        = "mxfp4x4x2-f32";
             quant_job_func = htp_quantize_fp32_q8x4x2;
             if (src1_nrows > 1) {
-                matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2;
+                matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
             } else {
-                matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2;
+                matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
             }
 
             src1_row_size = q8x4x2_row_size(ne10);  // row size post quantization
@@ -2106,8 +2189,8 @@ int op_matmul(struct htp_ops_context * octx) {
             // Entire src1 tensor is placed into the VTCM
             // For other tensors we allocate N rows per thread, padded to HVX vector size
 
-            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
             octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
 
             // src0 spad is also used in dynamic quantizer to store padded src1 rows
@@ -2122,20 +2205,69 @@ int op_matmul(struct htp_ops_context * octx) {
             break;
 
         case HTP_TYPE_F16:
-            op_type         = "f16-f32";
-            quant_job_func  = NULL;  // htp_quantize_f32_f16;
-            matmul_job_func = htp_matmul_f16_f32;
-
-            // For all tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256);
-            octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256);
-
-            octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
-            octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
-            octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
-
-            need_quant = false;
+            {
+                // Try optimized f16-f16 path first (src1 in VTCM)
+                const size_t f16_src1_row_size  = htp_round_up(ne10 * 2, 128);
+                const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256);
+                const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+                const size_t f16_dst_spad_size  = htp_round_up(MM_SPAD_DST_NROWS  * dst_row_size, 256) * octx->n_threads;
+
+                const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
+
+                // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
+                // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
+                const bool is_batched  = (ne02 > 1) || (ne03 > 1);
+                const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
+
+                if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+                    // Optimized path
+                    op_type        = "f16-f16";
+                    quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
+                    if (src1_nrows > 1) {
+                        matmul_job_func = htp_matmul_2d_f16_f16;
+                    } else {
+                        matmul_job_func = htp_matvec_2d_f16_f16;
+                    }
+
+                    src1_row_size = f16_src1_row_size; // row size post quantization
+
+                    octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+                    octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+                    octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+
+                    octx->src1_spad.size = octx->src1_spad.size_per_thread;
+                    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+                    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+                } else {
+                    // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required
+                    quant_job_func  = NULL;
+                    if (src1->type == HTP_TYPE_F32) {
+                        op_type         = "f16-f32";
+                        matmul_job_func = htp_matmul_4d_f16_f32;
+                    } else {
+                        op_type         = "f16-f16";
+                        matmul_job_func = htp_matmul_4d_f16_f16;
+                    }
+
+                    src1_row_size = nb11; // original row size in DDR
+
+                    octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+                    octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
+                    octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
+
+                    octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+                    octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+                    octx->dst_spad.size  = octx->dst_spad.size_per_thread * octx->n_threads;
+
+                    // Init fastdiv for matmul_4d (supports broadcasting)
+                    octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
+                    octx->mm_div_ne1      = init_fastdiv_values(dst->ne[1]);
+                    octx->mm_div_r2       = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
+                    octx->mm_div_r3       = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
+
+                    need_quant = false;
+                }
+            }
             break;
 
         default:
@@ -2166,6 +2298,9 @@ int op_matmul(struct htp_ops_context * octx) {
     octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
     octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1);  // round up to even
 
+    octx->src0_spad.stride = src0_row_size_padded;
+    octx->src1_spad.stride = src1_row_size;
+
     if (need_quant) {
         // Run quant jobs
         const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
@@ -2185,12 +2320,9 @@ int op_matmul(struct htp_ops_context * octx) {
 // ** main matmul-id entry point
 
 int op_matmul_id(struct htp_ops_context * octx) {
-    const struct htp_tensor * src0 = &octx->src0;
-    const struct htp_tensor * src1 = &octx->src1;
-    const struct htp_tensor * ids  = &octx->src2;
-    struct htp_tensor *       dst  = &octx->dst;
+    htp_matmul_tensors_preamble;
 
-    htp_matmul_preamble;
+    struct htp_tensor * restrict ids = &octx->src2;
 
     const char * op_type;
 
@@ -2228,8 +2360,8 @@ int op_matmul_id(struct htp_ops_context * octx) {
 
             // Entire src1 tensor is placed into the VTCM
             // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
             octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
             octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
 
@@ -2257,8 +2389,8 @@ int op_matmul_id(struct htp_ops_context * octx) {
 
             // Entire src1 tensor is placed into the VTCM
             // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
             octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
             octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
 
@@ -2286,8 +2418,8 @@ int op_matmul_id(struct htp_ops_context * octx) {
 
             // Entire src1 tensor is placed into the VTCM
             // For other tensors we allocate N rows per thread, padded to HVX vector size
-            octx->dst_spad.size_per_thread  = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
-            octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+            octx->dst_spad.size_per_thread  = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+            octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
             octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
             octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
 
diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c
new file mode 100644 (file)
index 0000000..bdd64fc
--- /dev/null
@@ -0,0 +1,168 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+#    define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define set_rows_preamble \
+    const uint32_t ne00 = octx->src0.ne[0]; \
+    const uint32_t ne01 = octx->src0.ne[1]; \
+    const uint32_t ne02 = octx->src0.ne[2]; \
+    const uint32_t ne03 = octx->src0.ne[3]; \
+                                            \
+    const uint32_t ne10 = octx->src1.ne[0]; \
+    const uint32_t ne11 = octx->src1.ne[1]; \
+    const uint32_t ne12 = octx->src1.ne[2]; \
+                                            \
+    const uint32_t nb01 = octx->src0.nb[1]; \
+    const uint32_t nb02 = octx->src0.nb[2]; \
+    const uint32_t nb03 = octx->src0.nb[3]; \
+                                            \
+    const uint32_t nb10 = octx->src1.nb[0]; \
+    const uint32_t nb11 = octx->src1.nb[1]; \
+    const uint32_t nb12 = octx->src1.nb[2]; \
+                                            \
+    const uint32_t nb1 = octx->dst.nb[1];   \
+    const uint32_t nb2 = octx->dst.nb[2];   \
+    const uint32_t nb3 = octx->dst.nb[3];   \
+                                            \
+    const uint32_t ne1 = octx->dst.ne[1];   \
+                                            \
+    const uint32_t nr  = ne01;
+
+static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+    set_rows_preamble;
+
+    // parallelize by rows of src0
+    const uint32_t dr  = octx->src0_nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
+        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
+            for (uint32_t i = ir0; i < ir1; ++i) {
+                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
+                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+                const uint32_t i10 = i;
+
+                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+                if (i1 >= ne1) {
+                    // ignore invalid indices
+                    continue;
+                }
+
+                const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
+                const uintptr_t dst_ptr  = octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;
+
+                // copy row
+                hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+            }
+        }
+    }
+
+    return HTP_STATUS_OK;
+}
+
+static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+    set_rows_preamble;
+
+    // parallelize by rows of src0
+    const uint32_t dr  = octx->src0_nrows_per_thread;
+    const uint32_t ir0 = dr * ith;
+    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+    const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
+        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
+            for (uint32_t i = ir0; i < ir1; ++i) {
+                const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
+                const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+                const uint32_t i10 = i;
+
+                const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+                if (i1 >= ne1) {
+                    // ignore invalid indices
+                    continue;
+                }
+
+                const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
+                uint8_t*       dst_ptr  = (uint8_t *)       octx->dst.data  + i1*nb1 + i02*nb2  + i03*nb3;
+
+                hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
+            }
+        }
+    }
+
+    return HTP_STATUS_OK;
+}
+
+static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
+    set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
+}
+
+static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
+    set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_set_rows(struct htp_ops_context * octx) {
+    set_rows_preamble;
+
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
+    octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
+
+    const uint32_t n_jobs = MIN(nr, octx->n_threads);
+    octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+
+    switch(octx->dst.type) {
+    case HTP_TYPE_F32:
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
+        break;
+    case HTP_TYPE_F16:
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
+        break;
+    default:
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    return HTP_STATUS_OK;
+}
index 5bf0cbf7922bbf9df48f657052eb9a4cb7cdeb8e..80d249a22c6de570c0892226ab8badc0a06db90d 100644 (file)
@@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
                     hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
                                               (const uint8_t *) mp_f32, slope);
                 } else {
-                    hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
+                    hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
                     if (mp_f32) {
                         if (softmax_ctx->use_f16) {
                             for (int i = 0; i < ne00; ++i) {
@@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
                     float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
                     float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
                     sum       = sum > 0.0 ? (1.0 / sum) : 1;
-                    hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
+                    hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
                 }
             }
         }
index bb7557b02526725b763c34a2b855bc9b6ace7588..8ed1e5b66198ed1db8e8a7188cd26f3e50b16df5 100644 (file)
@@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
     }
 }
 
+static void scale_htp_f32(const float * restrict src,
+                          float * restrict dst,
+                          uint8_t * restrict spad,
+                          const uint32_t num_rows,
+                          const uint32_t row_elems,
+                          const size_t   row_size,
+                          int32_t *      op_params,
+                          int            opt_path) {
+    float scale = 0.f;
+    float bias  = 0.f;
+    memcpy(&scale, &op_params[0], sizeof(float));
+    memcpy(&bias,  &op_params[1], sizeof(float));
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const float * restrict src_local = src + (ir * row_elems);
+        float * restrict dst_local       = dst + (ir * row_elems);
+
+        if (ir + 1 < num_rows) {
+            htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
+        }
+
+        hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
+    }
+}
+
 static void rms_norm_htp_f32(const float * restrict src,
                              float * restrict dst,
                              uint8_t * restrict spad,
@@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src,
             const float mean  = sum / row_elems;
             const float scale = 1.0f / sqrtf(mean + epsilon);
 
-            hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
+            hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
         }
     }
 }
@@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
         case HTP_OP_RMS_NORM:
             rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
             break;
+        case HTP_OP_SCALE:
+            scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+            break;
 
         default:
             break;
@@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
             unary_op_func = unary_job_dispatcher_f32;
             op_type       = "rmsnorm-f32";
             break;
+        case HTP_OP_SCALE:
+            unary_op_func = unary_job_dispatcher_f32;
+            op_type       = "scale-f32";
+            break;
 
         default:
             FARF(ERROR, "Unsupported unary Op %u\n", octx->op);