]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
hexagon: general DMA and Binary Op fixes for large strides (llama/20918)
authorMax Krasnyansky <redacted>
Mon, 23 Mar 2026 22:33:49 +0000 (15:33 -0700)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
* hex-dma: make chained dma the default to handle newer models

This also includes some new instrumentation that we can remove later.

* hexagon: add uint32 dump helper

* hexagon: use single-page VTCM allocation to avoid issues with large gather ops in ssm-conv

ssm-conv uses HVX gather instruction and that instruction cannot handle cases where the base+offset
spans page boundaries.

* hexagon: update ssm-conv to make base-addr compute a bit easier to read

* hex-dma: use 1d mode for reshaping, it supports sizes up to 24-bits (>16MB)

* hex-bin: fix incorrect stride logic

* hexagon: make sure repack buffs are dumped for verbose > 2

* hex-bin: consistently use dma_queue_push even for dummy dst transactions

* hex-dma: start using 2d-wide mode on v75 and up

The removes the need to deal with the 16-bit limitaion for the strides.

* hex-bin: cleanup kernel selection logic

* hex-bin: cleanup binary op core and fix transposed tensor handling

* snapdragon: update run-bench to use larger ubatch and fa-on

src/ggml-hexagon/ggml-hexagon.cpp
src/ggml-hexagon/htp/binary-ops.c
src/ggml-hexagon/htp/hex-dma.c
src/ggml-hexagon/htp/hex-dma.h
src/ggml-hexagon/htp/hex-dump.h
src/ggml-hexagon/htp/hmx-matmul-ops.c
src/ggml-hexagon/htp/hvx-utils.h
src/ggml-hexagon/htp/main.c
src/ggml-hexagon/htp/ssm-conv.c

index 8bcf5291c11fdda49f6c064eea97225f0e302261..9c1ce93cc69f4c4685664a73df41969434d91a8f 100644 (file)
@@ -461,7 +461,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) {
         d[7]          = x[i * 8 + 7].d;
     }
 
-    if (opt_verbose > 1) {
+    if (opt_verbose > 2) {
         for (int i = 0; i < nb; i++) {
             dump_packed_block_q4x4x2(y, i, k);
         }
@@ -480,7 +480,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) {
     const uint8_t * y_q = y + 0;              // quants first
     const uint8_t * y_d = y + qrow_size;      // then scales
 
-    if (opt_verbose > 1) {
+    if (opt_verbose > 2) {
         for (int i = 0; i < nb; i++) {
             dump_packed_block_q4x4x2(y, i, k);
         }
@@ -796,7 +796,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) {
         d[7]          = x[i * 8 + 7].d;
     }
 
-    if (opt_verbose > 1) {
+    if (opt_verbose > 2) {
         for (int i = 0; i < nb; i++) {
             dump_packed_block_q8x4x2(y, i, k);
         }
@@ -814,7 +814,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) {
     const uint8_t * y_q = y + 0;              // quants first
     const uint8_t * y_d = y + qrow_size;      // then scales
 
-    if (opt_verbose > 1) {
+    if (opt_verbose > 2) {
         for (int i = 0; i < nb; i++) {
             dump_packed_block_q8x4x2(y, i, k);
         }
@@ -1149,7 +1149,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k)
         e[7]        = x[i * 8 + 7].e;
     }
 
-    if (opt_verbose > 1) {
+    if (opt_verbose > 2) {
         for (int i = 0; i < nb; i++) {
             dump_packed_block_mxfp4x4x2(y, i, k);
         }
@@ -1168,7 +1168,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k)
     const uint8_t * y_q = y + 0;              // quants first
     const uint8_t * y_e = y + qrow_size;      // then scales
 
-    if (opt_verbose > 1) {
+    if (opt_verbose > 2) {
         for (int i = 0; i < nb; i++) {
             dump_packed_block_mxfp4x4x2(y, i, k);
         }
index ec90f22de52ee610503bf7ef41dba18984badeb6..1b0f97493bcf8dc60bc945d06970cc9c630001be 100644 (file)
 // Context for binary operations
 struct htp_binary_context {
     struct htp_ops_context * octx;
-    struct fastdiv_values dim1_div;
-    struct fastdiv_values dim2_div;
-    struct fastdiv_values dim12_div;
+
+    struct fastdiv_values src0_dim1_div; // ne01
+    struct fastdiv_values src0_dim2_div; // ne02
+    struct fastdiv_values src0_dim12_div;// ne03
 
     struct fastdiv_values src1_dim1_div; // ne11
     struct fastdiv_values src1_dim2_div; // ne12
     struct fastdiv_values src1_dim3_div; // ne13
 
-    uint32_t nrows_per_thread;
-    bool split_at_ne01;
-    bool split_at_ne02;
-
-    // Precomputed values
     uint32_t block_max;
+    uint32_t nrows_per_thread;
     size_t   src0_row_size_aligned;
     size_t   src1_row_size_aligned;
     size_t   dst_row_size_aligned;
-    uint32_t src1_fetch_rows; // 1 or block_max
-    uint32_t src1_dma_stride; // 0 or stride
+
+    bool split_at_ne01;
+    bool split_at_ne02;
 };
 
-#define htp_binary_preamble            \
+#define htp_binary_preamble                       \
     const struct htp_tensor * src0 = &octx->src0; \
     const struct htp_tensor * src1 = &octx->src1; \
     struct htp_tensor *       dst  = &octx->dst;  \
@@ -72,12 +70,11 @@ struct htp_binary_context {
     const uint32_t nb2 = dst->nb[2];   \
     const uint32_t nb3 = dst->nb[3];
 
-static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
-                                uint32_t ne01, uint32_t ne02) {
+static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) {
     uint32_t i03, i02, i01, rem;
-    i03 = fastdiv(ir, &bctx->dim12_div);
+    i03 = fastdiv(ir, &bctx->src0_dim12_div);
     rem = ir - i03 * (ne02 * ne01);
-    i02 = fastdiv(rem, &bctx->dim1_div);
+    i02 = fastdiv(rem, &bctx->src0_dim1_div);
     i01 = rem - i02 * ne01;
 
     uint32_t rows_left = end_row - ir;
@@ -191,6 +188,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
     const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
     if (start_row >= end_row) return;
 
+    FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
+
     uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
     uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
     size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;
@@ -204,9 +203,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
     for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
         uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
         uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
         rem = ir_prefetch - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
+        i02 = fastdiv(rem, &bctx->src0_dim1_div);
         i01 = rem - i02 * ne01;
 
         uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
@@ -215,7 +214,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
         uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
-        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
         dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
@@ -229,9 +228,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
         uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
 
         uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir, &bctx->dim12_div);
+        i03 = fastdiv(ir, &bctx->src0_dim12_div);
         rem = ir - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
+        i02 = fastdiv(rem, &bctx->src0_dim1_div);
         i01 = rem - i02 * ne01;
 
         // src1 indices (broadcast/repeat)
@@ -255,9 +254,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
              uint32_t p03, p02, p01, prem;
-             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
              prem = ir_prefetch - p03 * (ne02 * ne01);
-             p02 = fastdiv(prem, &bctx->dim1_div);
+             p02 = fastdiv(prem, &bctx->src0_dim1_div);
              p01 = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
 
@@ -282,6 +281,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
     const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
     if (start_row >= end_row) return;
 
+    FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
+
     uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
     uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
     uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
@@ -297,9 +298,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
     for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
         uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
         uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
         rem = ir_prefetch - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
+        i02 = fastdiv(rem, &bctx->src0_dim1_div);
         i01 = rem - i02 * ne01;
 
         uint32_t i13 = (ne13 == 1) ? 0 : i03;
@@ -307,23 +308,23 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
         uint32_t i11 = (ne11 == 1) ? 0 : i01;
 
         uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
-        uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+        uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
         uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
 
         uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
         uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
-        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
         dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
-        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
+        dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
     }
 
     for (uint32_t ir = start_row; ir < end_row; ) {
         uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
-        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * d_spad  = (uint8_t *) dma_queue_pop(q).src;
         uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
         uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
 
@@ -335,9 +336,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
         }
 
         uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir, &bctx->dim12_div);
+        i03 = fastdiv(ir, &bctx->src0_dim12_div);
         rem = ir - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
+        i02 = fastdiv(rem, &bctx->src0_dim1_div);
         i01 = rem - i02 * ne01;
         uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
         dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
@@ -345,9 +346,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
              uint32_t p03, p02, p01, prem;
-             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
              prem = ir_prefetch - p03 * (ne02 * ne01);
-             p02 = fastdiv(prem, &bctx->dim1_div);
+             p02 = fastdiv(prem, &bctx->src0_dim1_div);
              p01 = prem - p02 * ne01;
 
              uint32_t p13 = (ne13 == 1) ? 0 : p03;
@@ -358,7 +359,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
              uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
 
              dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
-             dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
+             dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size);
 
              ir_prefetch += next_block_size;
         }
@@ -373,15 +374,17 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
     struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
-    const uint32_t src0_type = octx->src0.type;
+    const uint32_t src0_type  = octx->src0.type;
     const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
     const uint32_t total_rows = ne01 * ne02 * ne03;
-    const uint32_t start_row = bctx->nrows_per_thread * ith;
-    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    const uint32_t start_row  = bctx->nrows_per_thread * ith;
+    const uint32_t end_row    = MIN(start_row + bctx->nrows_per_thread, total_rows);
     if (start_row >= end_row) return;
 
+    FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
+
     uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
-    uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+    uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
     uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
 
     size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
@@ -391,15 +394,14 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
     uint32_t ir_prefetch = start_row;
     int spad_idx = 0;
 
-    void * s1_ptr = (void *) src1_spad;
+    void * s1_ptr = (void *) src1_spad_base;
 
     for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
         uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-        rem = ir_prefetch - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+        uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
 
         uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
         uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
@@ -407,7 +409,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
         uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
-        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
         dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
@@ -415,7 +417,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
 
     for (uint32_t ir = start_row; ir < end_row; ) {
         uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
-        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * d_spad  = (uint8_t *) dma_queue_pop(q).src;
         uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
 
         for (uint32_t r = 0; r < current_block_size; r++) {
@@ -425,21 +427,19 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
             COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
         }
 
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir, &bctx->dim12_div);
-        rem = ir - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
+        uint32_t rem = ir - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
         uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
         dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-             uint32_t p03, p02, p01, prem;
-             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-             prem = ir_prefetch - p03 * (ne02 * ne01);
-             p02 = fastdiv(prem, &bctx->dim1_div);
-             p01 = prem - p02 * ne01;
+             uint32_t p03  = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+             uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
+             uint32_t p02  = fastdiv(prem, &bctx->src0_dim1_div);
+             uint32_t p01  = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
              dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
              ir_prefetch += next_block_size;
@@ -458,14 +458,16 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
     const uint32_t src0_type = octx->src0.type;
     const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
     const uint32_t total_rows = ne01 * ne02 * ne03;
-    const uint32_t start_row = bctx->nrows_per_thread * ith;
-    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    const uint32_t start_row  = bctx->nrows_per_thread * ith;
+    const uint32_t end_row    = MIN(start_row + bctx->nrows_per_thread, total_rows);
     if (start_row >= end_row) return;
 
+    FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
+
     uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
     uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
-    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
-    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+    size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half     = octx->dst_spad.size_per_thread  / 2;
 
     dma_queue * q = octx->ctx->dma[ith];
     uint32_t ir_prefetch = start_row;
@@ -473,11 +475,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
 
     for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
         uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-        rem = ir_prefetch - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+        uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
 
         uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
         uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
@@ -485,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
         uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
-        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
         dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
@@ -496,11 +497,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
         uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
         uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
 
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir, &bctx->dim12_div);
-        rem = ir - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
+        uint32_t rem = ir - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
 
         for (uint32_t r = 0; r < current_block_size; r++) {
             uint32_t r_i01 = i01 + r;
@@ -521,11 +521,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-             uint32_t p03, p02, p01, prem;
-             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-             prem = ir_prefetch - p03 * (ne02 * ne01);
-             p02 = fastdiv(prem, &bctx->dim1_div);
-             p01 = prem - p02 * ne01;
+             uint32_t p03  = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+             uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
+             uint32_t p02  = fastdiv(prem, &bctx->src0_dim1_div);
+             uint32_t p01  = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
              dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
              ir_prefetch += next_block_size;
@@ -545,14 +544,16 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
     const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
     const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
     const uint32_t total_rows = ne01 * ne02 * ne03;
-    const uint32_t start_row = bctx->nrows_per_thread * ith;
-    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    const uint32_t start_row  = bctx->nrows_per_thread * ith;
+    const uint32_t end_row    = MIN(start_row + bctx->nrows_per_thread, total_rows);
     if (start_row >= end_row) return;
 
     uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
     uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
-    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
-    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+    size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half     = octx->dst_spad.size_per_thread  / 2;
+
+    FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned);
 
     dma_queue * q = octx->ctx->dma[ith];
     uint32_t ir_prefetch = start_row;
@@ -560,11 +561,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
 
     for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
         uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-        rem = ir_prefetch - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+        uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
 
         uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
         uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
@@ -572,7 +572,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
         uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
-        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0);
         dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
@@ -583,11 +583,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
         uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
         uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
 
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir, &bctx->dim12_div);
-        rem = ir - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
+        uint32_t rem = ir - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
 
         for (uint32_t r = 0; r < current_block_size; r++) {
             uint32_t r_i01 = i01 + r;
@@ -612,11 +611,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-             uint32_t p03, p02, p01, prem;
-             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-             prem = ir_prefetch - p03 * (ne02 * ne01);
-             p02 = fastdiv(prem, &bctx->dim1_div);
-             p01 = prem - p02 * ne01;
+             uint32_t p03  = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+             uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
+             uint32_t p02  = fastdiv(prem, &bctx->src0_dim1_div);
+             uint32_t p01  = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
              dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
              ir_prefetch += next_block_size;
@@ -646,6 +644,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
     const uint32_t nb02 = src0->nb[2];
     const uint32_t nb03 = src0->nb[3];
     const uint32_t nb11 = src1->nb[1]; // src1 row stride
+
     const uint32_t nb1 = dst->nb[1];
     const uint32_t nb2 = dst->nb[2];
     const uint32_t nb3 = dst->nb[3];
@@ -657,8 +656,8 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
 
     uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
     uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
-    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
-    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+    size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half     = octx->dst_spad.size_per_thread  / 2;
 
     dma_queue * q = octx->ctx->dma[ith];
     uint32_t ir_prefetch = start_row;
@@ -666,11 +665,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
 
     for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
         uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-        rem = ir_prefetch - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+        uint32_t rem = ir_prefetch - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
 
         uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
         uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
@@ -678,7 +676,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
         uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
-        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0);
         dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
@@ -689,11 +687,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
         uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
         uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
 
-        uint32_t i03, i02, i01, rem;
-        i03 = fastdiv(ir, &bctx->dim12_div);
-        rem = ir - i03 * (ne02 * ne01);
-        i02 = fastdiv(rem, &bctx->dim1_div);
-        i01 = rem - i02 * ne01;
+        uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div);
+        uint32_t rem = ir - i03 * (ne02 * ne01);
+        uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div);
+        uint32_t i01 = rem - i02 * ne01;
 
         for (uint32_t r = 0; r < current_block_size; r++) {
             uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
@@ -712,11 +709,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
-             uint32_t p03, p02, p01, prem;
-             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
-             prem = ir_prefetch - p03 * (ne02 * ne01);
-             p02 = fastdiv(prem, &bctx->dim1_div);
-             p01 = prem - p02 * ne01;
+             uint32_t p03  = fastdiv(ir_prefetch, &bctx->src0_dim12_div);
+             uint32_t prem = ir_prefetch - p03 * (ne02 * ne01);
+             uint32_t p02  = fastdiv(prem, &bctx->src0_dim1_div);
+             uint32_t p01  = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
              dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
              ir_prefetch += next_block_size;
@@ -739,40 +735,36 @@ static int execute_op_binary(struct htp_ops_context * octx) {
     const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
     const size_t src0_row_size = src0->ne[0] * elem_size;
     const size_t src1_row_size = src1->ne[0] * elem_size;
-    const size_t dst_row_size  = dst->ne[0] * elem_size;
+    const size_t dst_row_size  = dst->ne[0]  * elem_size;
 
-    // Align to VLEN
-    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
-    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
     size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+    size_t dst_row_size_aligned  = hex_round_up(dst_row_size,  VLEN);
 
     bool is_add_id = (octx->op == HTP_OP_ADD_ID);
     bool is_scalar = !is_add_id && (src1->ne[0] == 1);
 
-    // Determine which kernel we will use to alloc memory and dispatch
-    bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
+    bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size);
+
+    bool is_same_shape = !is_add_id && !is_scalar && !is_transposed &&
+               (src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) &&
                (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
                (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
                (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
 
-    bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
-    bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
-    bool use_repeat  = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
+    bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+    bool is_complex   = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]);
+    bool is_repeat    = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]);
 
     size_t spad_row_total;
-    if (is_scalar) {
-        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
-    } else if (is_row_bcast) {
-        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
-    } else if (use_vector_same) {
+    if (is_same_shape) {
         spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
-    } else if (is_add_id) {
-        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
     } else {
         spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
     }
 
     size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
+
     // Adjust for static src1 in row_bcast case
     if (is_row_bcast) {
         size_t needed_static = src1_row_size_aligned;
@@ -782,28 +774,26 @@ static int execute_op_binary(struct htp_ops_context * octx) {
     }
 
     if (rows_per_buffer < 1) {
-         FARF(ERROR, "binary: VTCM too small\n");
-         return HTP_STATUS_VTCM_TOO_SMALL;
+        FARF(ERROR, "binary: VTCM too small\n");
+        return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
     octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
     octx->dst_spad.size_per_thread  = rows_per_buffer * 2 * dst_row_size_aligned;
 
-    if (is_scalar || use_complex || use_repeat || is_add_id) {
-        octx->src1_spad.size_per_thread = 0;
-    } else if (is_row_bcast) {
+    if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) {
         octx->src1_spad.size_per_thread = 0;
     } else {
         octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
     }
 
+    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;
     octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
     if (is_row_bcast) {
         octx->src1_spad.size = src1_row_size_aligned;
     } else {
         octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
     }
-    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;
 
     if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
         return HTP_STATUS_VTCM_TOO_SMALL;
@@ -823,46 +813,37 @@ static int execute_op_binary(struct htp_ops_context * octx) {
     }
 
     struct htp_binary_context bctx;
-    bctx.octx = octx;
-    bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
-    bctx.block_max = rows_per_buffer;
+    bctx.octx                  = octx;
+    bctx.nrows_per_thread      = (src0_nrows + n_threads - 1) / n_threads;
+    bctx.block_max             = rows_per_buffer;
     bctx.src0_row_size_aligned = src0_row_size_aligned;
     bctx.src1_row_size_aligned = src1_row_size_aligned;
     bctx.dst_row_size_aligned  = dst_row_size_aligned;
 
-    bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
-    bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
-    bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
+    bctx.src0_dim1_div  = init_fastdiv_values(src0->ne[1]);
+    bctx.src0_dim2_div  = init_fastdiv_values(src0->ne[2]);
+    bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
 
-    bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
-    bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
-    bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
+    bctx.src1_dim1_div  = init_fastdiv_values(src1->ne[1]);
+    bctx.src1_dim2_div  = init_fastdiv_values(src1->ne[2]);
+    bctx.src1_dim3_div  = init_fastdiv_values(src1->ne[3]);
 
     bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
-    bool dst_contig_dim1  = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
+    bool dst_contig_dim1  = (dst->nb[2]  == src0->ne[1] * dst->nb[1]);
 
     bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
-    bool dst_contig_dim2  = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
-
-    bctx.split_at_ne01 = (src0->ne[2] > 1) &&
-                         ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
+    bool dst_contig_dim2  = (dst->nb[3]  == src0->ne[2] * dst->nb[2]);
 
-    bctx.split_at_ne02 = (src0->ne[3] > 1) &&
-                         ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
-
-    // Precompute specific kernel parameters
-    if (use_vector_same) {
-        bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
-        bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
-    }
+    bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
+    bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
 
     worker_callback_t worker_func;
-    if (is_add_id) worker_func = binary_job_add_id;
-    else if (is_scalar) worker_func = binary_job_scalar;
-    else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
-    else if (use_vector_same) worker_func = binary_job_vector_same_shape;
-    else if (use_complex) worker_func = binary_job_vector_complex;
-    else worker_func = binary_job_element_repeat;
+    if (is_add_id)          worker_func = binary_job_add_id;
+    else if (is_scalar)     worker_func = binary_job_scalar;
+    else if (is_row_bcast)  worker_func = binary_job_vector_row_broadcast;
+    else if (is_same_shape) worker_func = binary_job_vector_same_shape;
+    else if (is_complex)    worker_func = binary_job_vector_complex;
+    else                    worker_func = binary_job_element_repeat;
 
     if (is_row_bcast) {
         dma_queue_pop(q);
index 44e1be40c5d7e8c20538326a8c6ad3f2eb0d2620..b66e2d2603ceab27f61d475c35f1805345991dd7 100644 (file)
@@ -31,8 +31,8 @@ dma_queue * dma_queue_create(size_t capacity) {
     q->capacity = capacity;
     q->idx_mask = capacity - 1;
 
-    q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));
-    memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));
+    q->desc = (dma_descriptor_2d *) memalign(64, capacity * sizeof(dma_descriptor_2d));
+    memset(q->desc, 0, capacity * sizeof(dma_descriptor_2d));
 
     q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr));
     memset(q->dptr, 0, capacity * sizeof(dma_ptr));
index 9811a07599fb22da28aa39d6a47653f09e03a85c..ff166cbcc7af5999afc5c2871ef1a323dbd0164a 100644 (file)
 extern "C" {
 #endif
 
+// Define the HW descriptor structs here since the ones in HexSDK are a bit out of date
+typedef struct dma_descriptor_1d_s {
+    void *   next;
+    uint32_t size:24;
+    uint32_t desc_size:2;
+    uint32_t dst_comp:1;
+    uint32_t src_comp:1;
+    uint32_t dst_bypass:1;
+    uint32_t src_bypass:1;
+    uint32_t order:1;
+    uint32_t done:1;
+    void *   src;
+    void *   dst;
+} dma_descriptor_1d;
+
+#if __HVX_ARCH__ < 75
+
+typedef struct dma_descriptor_2d_s {
+    void *   next;
+    uint32_t reserved0:24;
+    uint32_t desc_size:2;
+    uint32_t dst_comp:1;
+    uint32_t src_comp:1;
+    uint32_t dst_bypass:1;
+    uint32_t src_bypass:1;
+    uint32_t order:1;
+    uint32_t done:1;
+    void *   src;
+    void *   dst;
+    uint32_t desc_type:8;
+    uint32_t reserved1:24;
+    uint32_t row_size:16;
+    uint32_t nrows:16;
+    uint32_t src_stride:16;
+    uint32_t dst_stride:16;
+    uint32_t src_offset:16;
+    uint32_t dst_offset:16;
+} dma_descriptor_2d;
+
+#else
+
+typedef struct dma_descriptor_2d_s {
+    void *   next;
+    uint32_t dst_stride:24;
+    uint32_t desc_size:2;
+    uint32_t dst_comp:1;
+    uint32_t src_comp:1;
+    uint32_t dst_bypass:1;
+    uint32_t src_bypass:1;
+    uint32_t order:1;
+    uint32_t done:1;
+    void *   src;
+    void *   dst;
+    uint32_t desc_type:8;
+    uint32_t reserved0:24;
+    uint32_t row_size:24;
+    uint32_t nrows_lo:8;
+    uint32_t nrows_hi:8;
+    uint32_t src_stride:24;
+    uint32_t offset:24;
+    uint32_t reserved1:8;
+} dma_descriptor_2d;
+
+#endif
+
 typedef struct {
-    void *dst;
+    void       *dst;
     const void *src;
 } dma_ptr;
 
 typedef struct {
-    hexagon_udma_descriptor_type1_t * desc;  // descriptor pointers
-    hexagon_udma_descriptor_type1_t * tail;  // tail pointer
-    dma_ptr                         * dptr;  // dst/src pointers
-    uint32_t                          push_idx;
-    uint32_t                          pop_idx;
-    uint32_t                          capacity;
-    uint32_t                          idx_mask;
+    dma_descriptor_2d * desc;  // descriptor pointers
+    dma_descriptor_2d * tail;  // tail pointer
+    dma_ptr           * dptr;  // dst/src pointers
+    uint32_t            push_idx;
+    uint32_t            pop_idx;
+    uint32_t            capacity;
+    uint32_t            idx_mask;
 } dma_queue;
 
 dma_queue * dma_queue_create(size_t capacity);
@@ -59,71 +124,87 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
     return p;
 }
 
-static inline bool dma_queue_push(dma_queue * q,
-                                  dma_ptr     dptr,
-                                  size_t      dst_row_size,
-                                  size_t      src_row_size,
-                                  size_t      width, // width in bytes. number of bytes to transfer per row
-                                  size_t      nrows) {
+#if __HVX_ARCH__ < 73
+static const uint32_t dma_src_l2_bypass_on = 1;
+static const uint32_t dma_dst_l2_bypass_on = 0;
+#else
+static const uint32_t dma_src_l2_bypass_on = 1;
+static const uint32_t dma_dst_l2_bypass_on = 1;
+#endif
+
+static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
     if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
-        FARF(ERROR, "dma-push: queue full\n");
+        FARF(HIGH, "dma-push: queue full\n");
         return false;
     }
 
-    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];
+    dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
+    desc->next       = NULL;
+    desc->desc_size  = 0; // 1D mode
+    desc->src_bypass = dma_src_l2_bypass_on;
+    desc->dst_bypass = dma_dst_l2_bypass_on;
+    desc->order      = 1;
+    desc->done       = 0;
+    desc->src        = (void *) dptr.src;
+    desc->dst        = (void *) dptr.dst;
+    desc->size       = size;
+
+    q->dptr[q->push_idx] = dptr;
+
+    dmlink(q->tail, desc);
+    q->tail = (dma_descriptor_2d *) desc;
+
+    // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
+    q->push_idx = (q->push_idx + 1) & q->idx_mask;
+    return true;
+}
+
+static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
+    if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
+        FARF(HIGH, "dma-push: queue full\n");
+        return false;
+    }
+
+    dma_descriptor_2d * desc = &q->desc[q->push_idx];
 
     desc->next           = NULL;
-    desc->length         = 0;
-    desc->desctype       = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
-    desc->dstbypass      = 1;
-    desc->srcbypass      = 1;
-#if __HVX_ARCH__ >= 73
-    desc->dstbypass      = 1;
-    desc->srcbypass      = 1;
-#else
-    desc->dstbypass      = 0;
-    desc->srcbypass      = 1;
-#endif
-    desc->order          = 0;
-    desc->dstate         = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
+    desc->reserved0      = 0;
+    desc->reserved1      = 0;
+    desc->desc_size      = 1; // 2d mode
+    desc->src_bypass     = dma_src_l2_bypass_on;
+    desc->dst_bypass     = dma_dst_l2_bypass_on;
+    desc->src_comp       = 0;
+    desc->dst_comp       = 0;
+    desc->order          = 1;
+    desc->done           = 0;
+    desc->src_stride     = src_stride;
+    desc->dst_stride     = dst_stride;
     desc->src            = (void *) dptr.src;
     desc->dst            = (void *) dptr.dst;
-    desc->allocation     = 0;
-    desc->padding        = 0;
-    desc->roiwidth       = width;
-    desc->roiheight      = nrows;
-    desc->srcstride      = src_row_size;
-    desc->dststride      = dst_row_size;
-    desc->srcwidthoffset = 0;
-    desc->dstwidthoffset = 0;
+    desc->row_size       = row_size;
+
+#if __HVX_ARCH__ < 75
+    desc->desc_type      = 0; // 2d (16-bit) mode
+    desc->nrows          = nrows;
+    desc->src_offset     = 0;
+    desc->dst_offset     = 0;
+#else
+    desc->desc_type      = 9; // 2d (24-bit) mode
+    desc->nrows_lo       = (nrows & 0xff);
+    desc->nrows_hi       = (nrows >> 8);
+    desc->offset         = 0;
+#endif
 
     q->dptr[q->push_idx] = dptr;
 
     dmlink(q->tail, desc);
     q->tail = desc;
 
-    // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
+    // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
     q->push_idx = (q->push_idx + 1) & q->idx_mask;
     return true;
 }
 
-static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q,
-                                              dma_ptr     dptr,
-                                              size_t      dst_row_size,
-                                              size_t      src_row_size,
-                                              size_t      nrows) {
-    return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
-}
-
-
-static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q,
-                                              dma_ptr     dptr,
-                                              size_t      dst_row_size,
-                                              size_t      src_row_size,
-                                              size_t      nrows) {
-    return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
-}
-
 static inline dma_ptr dma_queue_pop(dma_queue * q) {
     dma_ptr dptr  = { NULL };
 
@@ -131,12 +212,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
         return dptr;
     }
 
-    hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];
+    dma_descriptor_2d * desc = &q->desc[q->pop_idx];
 
     // Wait for desc to complete
     while (1) {
         dmpoll();
-        if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {
+        if (desc->done) {
             break;
         }
         // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
@@ -175,86 +256,62 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) {
     return q->capacity;
 }
 
-// ---------------------------------------------------------------------------
-// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth,
-// roiheight, srcstride, dststride) are 16-bit, max 65535.  This helper
-// transparently handles values that exceed the 16-bit limit and submits
-// chained DMA transtions.
-//
-// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push.
-// Case 2 (contiguous block): width == srcstride == dststride.  Reshape the
-//   flat transfer into a 2D descriptor with sub_width <= 65535.  Produces a
-//   single descriptor, preserving async DMA behavior.
-// Case 3 (stride overflow): srcstride or dststride > 65535.  Issue rows
-//   one at a time.  The first N-1 rows are pushed+popped synchronously;
-//   the last row is left async so the caller can pop it.
-// ---------------------------------------------------------------------------
-#define UDMA_MAX_FIELD_VAL 65535u
-
-static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) {
-    // Fast path: everything fits in 16 bits.
-    if (__builtin_expect(
-            width      <= UDMA_MAX_FIELD_VAL &&
-            nrows      <= UDMA_MAX_FIELD_VAL &&
-            src_stride <= UDMA_MAX_FIELD_VAL &&
-            dst_stride <= UDMA_MAX_FIELD_VAL, 1)) {
-        return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows);
-    }
+#if __HVX_ARCH__ < 75
 
-    // Case 2: contiguous block (width == src_stride == dst_stride).
-    // Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535.
-    if (width == src_stride && width == dst_stride) {
-        size_t total = width * nrows;
+// Overflow-safe DMA push: all 2d descriptor fields (row_size, nrows, src_stride, dst_stride) are 16-bit, max 65535.
+// This version transparently handles values that exceed the 16-bit limit and submits chained DMA transtions.
 
-        // Pick the largest 128-byte-aligned sub_width that divides total evenly.
-        size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127;  // 65408
-        while (sub_width > 0 && total % sub_width != 0) {
-            sub_width -= 128;
-        }
-        if (sub_width == 0) {
-            // Fallback: use original width (must fit) with adjusted nrows.
-            // This shouldn't happen for 128-aligned DMA sizes.
-            sub_width = width;
-        }
-        size_t sub_nrows = total / sub_width;
-
-        // Handle sub_nrows > 65535 by issuing chunked descriptors.
-        const uint8_t *src = (const uint8_t *)dptr.src;
-        uint8_t       *dst = (uint8_t *)dptr.dst;
-        size_t rows_done = 0;
-        while (rows_done < sub_nrows) {
-            size_t chunk = sub_nrows - rows_done;
-            if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL;
-
-            dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width);
-            if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk))
-                return false;
+#define DMA_MAX_FIELD_VAL 65535u
 
-            rows_done += chunk;
-            // Complete all chunks without waiting except the last one, so the
-            // caller's single dma_queue_pop drains the final descriptor.
-            if (rows_done < sub_nrows)
-                dma_queue_pop_nowait(q);
-        }
-        return true;
+static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
+    // Fast path: everything fits in 16 bits
+    if (nrows == 0 || __builtin_expect(
+            row_size   <= DMA_MAX_FIELD_VAL &&
+            nrows      <= DMA_MAX_FIELD_VAL &&
+            src_stride <= DMA_MAX_FIELD_VAL &&
+            dst_stride <= DMA_MAX_FIELD_VAL, 1)) {
+        return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows);
     }
 
-    // Case 3: stride overflow — fall back to row-by-row.
+    // Contiguous block
+    // Use 1d DMA mode which supports sizes up to 24-bits (16MB)
+    if (nrows == 1 || (row_size == src_stride && row_size == dst_stride)) {
+        size_t total = row_size * nrows;
+        return dma_queue_push_single_1d(q, dptr, total);
+    }
+
+    // Stride overflow — fall back to row-by-row.
     {
-        const uint8_t *src = (const uint8_t *)dptr.src;
-        uint8_t       *dst = (uint8_t *)dptr.dst;
+        const uint8_t *src = (const uint8_t *) dptr.src;
+        uint8_t       *dst = (uint8_t *)       dptr.dst;
         for (size_t r = 0; r < nrows; ++r) {
-          dma_ptr p = dma_make_ptr(dst + r * dst_stride,
-                                   src + r * src_stride);
-          if (!dma_queue_push(q, p, 0, 0, width, 1))
-            return false;
-          if (r + 1 < nrows)
-            dma_queue_pop_nowait(q);
+            dma_ptr p = dma_make_ptr(dst + r * dst_stride, src + r * src_stride);
+            if (!dma_queue_push_single_1d(q, p, row_size))
+                return false;
+            if (r + 1 < nrows)
+                dma_queue_pop(q);
         }
         return true;
     }
 }
 
+#else // HVX_ARCH >= 75
+
+static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) {
+    // On v75 and up we always use 2d 24-bit mode
+    return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows);
+}
+
+#endif
+
+static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) {
+    return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows);
+}
+
+static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) {
+    return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
index e3badb57f92e1fa0b4497a563a4c69a7469dadb3..19d173c223285fed6f019508758178caadf9e0c9 100644 (file)
@@ -21,6 +21,15 @@ static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t
     FARF(HIGH, "%s\n", str);
 }
 
+static inline void hex_dump_uint32_line(char * pref, const uint32_t * x, uint32_t n) {
+    char str[1024], *p = str, *p_end = str + sizeof(str);
+    p += snprintf(p, p_end - p, "%s: ", pref);
+    for (int i = 0; i < n; i++) {
+        p += snprintf(p, p_end - p, "%u, ", (unsigned int) x[i]);
+    }
+    FARF(HIGH, "%s\n", str);
+}
+
 static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
     char str[1024], *p = str, *p_end = str + sizeof(str);
     p += snprintf(p, p_end - p, "%s: ", pref);
index c703a049426a36e68062080908f1871875cfb497..a56356bee9f521c5bf44d4f54364461a18cbed94 100644 (file)
@@ -727,7 +727,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
                     if (use_dma_activation) {
                         const size_t row_bytes    = (size_t) params->k * sizeof(float);
                         const size_t stride_bytes = (size_t) params->act_stride * sizeof(float);
-                        dma_queue_push_chained(ctx->dma[0],
+                        dma_queue_push(ctx->dma[0],
                                           dma_make_ptr(vtcm_f32_act, activation_chunk),
                                           row_bytes, stride_bytes, row_bytes, n_rows);
                         dma_queue_pop(ctx->dma[0]);
@@ -747,7 +747,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
 
                 {
                     const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols);
-                    dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, weight_group),
+                    dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group),
                                       fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
                 }
 
@@ -765,7 +765,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu
                             const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols);
                             const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride;
 
-                            dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
+                            dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
                                               fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
                         }
 
@@ -891,7 +891,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
             if (use_dma_activation) {
                 const size_t row_bytes    = (size_t) k * sizeof(float);
                 const size_t stride_bytes = (size_t) act_stride * sizeof(float);
-                dma_queue_push_chained(ctx->dma[0],
+                dma_queue_push(ctx->dma[0],
                                   dma_make_ptr(vtcm_f32_act, activation_chunk),
                                   row_bytes, stride_bytes, row_bytes, n_rows);
                 dma_queue_pop(ctx->dma[0]);
@@ -916,7 +916,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
         {
             const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
 
-            dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight),
+            dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight),
                               fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first);
         }
 
@@ -933,7 +933,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co
                     const size_t n_cols_next       = hex_smin(n - nc_next, n_chunk_n_cols);
                     const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride;
 
-                    dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
+                    dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk),
                                       fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next);
                 }
 
@@ -1104,7 +1104,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
             // because UDMA roiwidth is 16-bit and total size can exceed 65535.
             {
                 const size_t n_cols_first = hex_smin(n, n_chunk_n_cols);
-                dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
+                dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first);
             }
 
             for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) {
@@ -1120,7 +1120,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
 
                         const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride;
 
-                        dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
+                        dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next);
                     }
 
                     // Dequant + vscatter writes directly to [K, N] transposed tiles.
@@ -1173,7 +1173,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
             {
                 // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow.
                 const uint8_t *qweight_chunk_A0 = permuted_weight;
-                dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
+                dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0);
             }
 
             {
@@ -1191,7 +1191,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
                 const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols);
                 if (1 < n_chunk_cnt) {
                     const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride;
-                    dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
+                    dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1);
                 }
 
                 // C0
@@ -1218,7 +1218,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds
                 // issue A_{i+2}
                 if (i + 2 < n_chunk_cnt) {
                     const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride;
-                    dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
+                    dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2);
                 }
 
                 // wait for HMX (C_{i}) -- C_{i} is done
@@ -1443,7 +1443,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
                 {
                     const float *activation_block = x + mr * k + kk;
 
-                    dma_queue_push_chained(ctx->dma[0],
+                    dma_queue_push(ctx->dma[0],
                                      dma_make_ptr(vtcm_scratch1, activation_block),
                                      k_blk_sz * sizeof(float),
                                      k * sizeof(float),
@@ -1472,10 +1472,10 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict
                     s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE;
 
                     // 2D DMA: quants sub-range
-                    dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
+                    dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off),
                                       s.dst_stride, s.src_stride, s.quant_width, s.n_rows);
                     // 2D DMA: scales sub-range
-                    dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off),
+                    dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off),
                                       s.dst_stride, s.src_stride, s.scale_width, s.n_rows);
                 }
                 TIMER_STOP(fetch);
index 083437987946dedcded18b07540db43bda6dd7a1..a518ad37331dee7040ea60497bbd9f1b7caf40e4 100644 (file)
 #include "hvx-div.h"
 #include "hvx-base.h"
 
-#ifndef GATHER_TYPE
-#    if defined(__hexagon__)
-#        define GATHER_TYPE(_a) (intptr_t) _a
-#    else
-#        define GATHER_TYPE(_a) (HVX_Vector *) _a
-#    endif
-#endif
-
 #endif /* HVX_UTILS_H */
index ef9cba8ecc15fecfb6e42c3478ac7c39ea354d9d..70ba9f9f4fe45581bb99ba4e896902beb3872050 100644 (file)
@@ -214,7 +214,7 @@ static int vtcm_alloc(struct htp_context * ctx) {
     HAP_compute_res_attr_init(&attr);
     HAP_compute_res_attr_set_serialize(&attr, 0);
     HAP_compute_res_attr_set_cache_mode(&attr, 1);
-    HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size);
+    HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); // single page
     HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
     HAP_compute_res_attr_set_hmx_param(&attr, 1);
 
@@ -319,7 +319,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
     ctx->n_threads = n_hvx;
     for (int i = 0; i < ctx->n_threads; i++) {
         // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541
-        ctx->dma[i] = dma_queue_create(64);
+        ctx->dma[i] = dma_queue_create(128);
     }
 
     // init worker pool
index b3c1ef9572e4c638b51eafebbf02d88377385c81..6b035810d57ae25c0f0db06c8b77e490f56a4330 100644 (file)
@@ -151,7 +151,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
     const int dr = scctx->nrows_per_thread;
     const uint32_t ir0 = dr * ith;
     const uint32_t ir1 = MIN(ir0 + dr, d_inner);
-    const int      ir  = ir1 - ir0;
+    const uint32_t ir  = ir1 - ir0;
 
     if (ir0 >= ir1) {
         return;  // No work for this thread
@@ -205,10 +205,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
                 HVX_Vector acc_vec = Q6_V_vsplat_R(0);
 
                 for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
-                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
-                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
-                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
-                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+                    uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
+                    uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc)  * sizeof(float);
+                    Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
 
                     HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
                     acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);
@@ -222,10 +222,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void
                 HVX_Vector acc_vec = Q6_V_vsplat_R(0);
 
                 for (uint32_t i0 = 0; i0 < d_conv; ++i0) {
-                    Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])),
-                                     src0_gather_len, (*(const HVX_Vector *) src0_offsets));
-                    Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)),
-                                     src1_gather_len, (*(const HVX_Vector *) src1_offsets));
+                    uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]);
+                    uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc)  * sizeof(float);
+                    Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets));
+                    Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets));
 
                     HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec);
                     acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod);