]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
hexagon: Add ARGSORT, DIV, SQR, SQRT, SUM_ROWS, GEGLU (llama/19406)
authorMax Krasnyansky <redacted>
Wed, 11 Feb 2026 07:21:12 +0000 (23:21 -0800)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
* hexagon: add ARGSORT op

Co-authored-by: Yarden Tal <redacted>
* hexagon: argsort reject tensors with huge rows for now

* Adding support for DIV,SQR,SQRT,SUM_ROWS ops in hexagon backend

* hexagon : Add GEGLU op

* hexagon: fix editor config check

* hexagon: rewrite and optimize binary ops ADD/SUB/MUL/DIV/ADD_ID to use DMA

---------

Co-authored-by: Yarden Tal <redacted>
Co-authored-by: Manohara Hosakoppa Krishnamurthy <redacted>
17 files changed:
src/ggml-hexagon/ggml-hexagon.cpp
src/ggml-hexagon/htp/CMakeLists.txt
src/ggml-hexagon/htp/act-ops.c
src/ggml-hexagon/htp/argsort-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/binary-ops.c
src/ggml-hexagon/htp/htp-msg.h
src/ggml-hexagon/htp/htp-ops.h
src/ggml-hexagon/htp/hvx-arith.h
src/ggml-hexagon/htp/hvx-base.h
src/ggml-hexagon/htp/hvx-copy.h
src/ggml-hexagon/htp/hvx-div.h [new file with mode: 0644]
src/ggml-hexagon/htp/hvx-sigmoid.h
src/ggml-hexagon/htp/hvx-sqrt.h
src/ggml-hexagon/htp/hvx-utils.h
src/ggml-hexagon/htp/main.c
src/ggml-hexagon/htp/sum-rows-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/unary-ops.c

index 4f0a1620fbf32a429b2b8203322ffef5ac1733b6..54f9986498f49d45af1cd00e7b6fb3be8958b8d2 100644 (file)
@@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
         return false;
     }
 
-    // TODO: add support for non-contigiuos tensors
-    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
-        return false;
-    }
-
     return true;
 }
 
@@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
     return true;
 }
 
+static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * dst  = op;
+
+    if (!hex_supported_src0_type(src0->type)) {
+        return false;
+    }
+    if (!hex_supported_dst_type(dst->type)) {
+        return false;
+    }
+
+    // TODO: add support for non-contigiuos tensors
+    if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
                                                const struct ggml_tensor *          op) {
     const struct ggml_tensor * src0 = op->src[0];
@@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
     return true;
 }
 
+static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0]; // values
+    const struct ggml_tensor * dst  = op;         // indices
+
+    if (src0->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (dst->type != GGML_TYPE_I32) {
+        return false;
+    }
+
+    if (src0->ne[0] > (16*1024)) {
+        // reject tensors with huge rows for now
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
     const int32_t * op_params = &op->op_params[0];
 
@@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
         case GGML_OP_SUB:
             req->op = HTP_OP_SUB;
             break;
+        case GGML_OP_DIV:
+            req->op = HTP_OP_DIV;
+            break;
         default:
             GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
             break;
@@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
     return n_bufs;
 }
 
+static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_ARGSORT;
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 template <bool _is_src0_constant>
 static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     switch (t->op) {
@@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
             supported = true;
             break;
 
+        case GGML_OP_SQR:
+            req->op   = HTP_OP_SQR;
+            supported = true;
+            break;
+
+        case GGML_OP_SQRT:
+            req->op   = HTP_OP_SQRT;
+            supported = true;
+            break;
+
         case GGML_OP_UNARY:
             if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
                 req->op   = HTP_OP_UNARY_SILU;
@@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
             } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
                 req->op   = HTP_OP_GLU_SWIGLU_OAI;
                 supported = true;
+            } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
+                req->op   = HTP_OP_GLU_GEGLU;
+                supported = true;
             }
             break;
 
@@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
     return n_bufs;
 }
 
+static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+    req->op = HTP_OP_SUM_ROWS;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
     req->op = HTP_OP_ROPE;
@@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             case GGML_OP_MUL:
             case GGML_OP_ADD:
             case GGML_OP_SUB:
+            case GGML_OP_DIV:
                 ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
                 break;
             case GGML_OP_ADD_ID:
@@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             case GGML_OP_SCALE:
                 ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
                 break;
+            case GGML_OP_SQR:
+            case GGML_OP_SQRT:
+                ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+                break;
+            case GGML_OP_SUM_ROWS:
+                ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
+                break;
             case GGML_OP_UNARY:
                 if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
                         (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
@@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 break;
             case GGML_OP_GLU:
                 if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
-                        (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
+                        (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
+                        (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
                     ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
                 }
                 break;
@@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
                 break;
 
+            case GGML_OP_ARGSORT:
+                ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
+                break;
+
             default:
                 GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
         }
@@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
         case GGML_OP_MUL:
         case GGML_OP_ADD:
         case GGML_OP_SUB:
+        case GGML_OP_DIV:
             supp = ggml_hexagon_supported_binary(sess, op);
             break;
 
@@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_unary(sess, op);
             break;
 
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+            supp = ggml_hexagon_supported_unary(sess, op);
+            break;
+
+        case GGML_OP_SUM_ROWS:
+            supp = ggml_hexagon_supported_sum_rows(sess, op);
+            break;
+
         case GGML_OP_SOFT_MAX:
             supp = ggml_hexagon_supported_softmax(sess, op);
             break;
@@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
         case GGML_OP_GLU:
             {
                 const auto glu_op = ggml_get_glu_op(op);
-                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
+                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
                     supp = ggml_hexagon_supported_activations(sess, op);
                 }
                 break;
@@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_cpy(sess, op);
             break;
 
+        case GGML_OP_ARGSORT:
+            supp = ggml_hexagon_supported_argsort(sess, op);
+            break;
+
         default:
             break;
     }
index e8ef203045c3fd1b365a64f64c752de6d3daf81e..2c23b60da3d1328875d3a516c411715394edb3fa 100644 (file)
@@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
 include_directories(
     ${HEXAGON_SDK_ROOT}/incs
     ${HEXAGON_SDK_ROOT}/incs/stddef
+    ${CMAKE_CURRENT_SOURCE_DIR}/../../../include
     ${CMAKE_CURRENT_SOURCE_DIR}/../..
     ${CMAKE_CURRENT_SOURCE_DIR}/..
     ${CMAKE_CURRENT_SOURCE_DIR}
@@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED
     matmul-ops.c
     binary-ops.c
     unary-ops.c
+    sum-rows-ops.c
     softmax-ops.c
     act-ops.c
     rope-ops.c
@@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED
     set-rows-ops.c
     get-rows-ops.c
     cpy-ops.c
+    argsort-ops.c
 )
 
 target_compile_definitions(${HTP_LIB} PRIVATE
index c3daf5adb2e8b2d2aab06779ebda05fe3f098dfc..950d836ad349d52f148a9bea5d9cf2da303f1200 100644 (file)
@@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
             // gelu = x * sigmoid(1.702 * x) // current implementation
             hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
             hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
-            hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
 
             // silu = x * sigmoid(x)
             hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
-            hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+            hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
         }
 
         dma_queue_push_vtcm_to_ddr(dma_queue,
@@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
          ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
 }
 
+static const float GELU_COEF_A     = 0.044715f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
+                                       const struct htp_tensor * src1,
+                                       struct htp_tensor *       dst,
+                                       const int32_t *           op_params,
+                                       struct htp_spad *         src0_spad,
+                                       struct htp_spad *         src1_spad,
+                                       struct htp_spad *         dst_spad,
+                                       uint32_t                  nth,
+                                       uint32_t                  ith,
+                                       uint32_t                  src0_nrows_per_thread,
+                                       dma_queue *               dma_queue) {
+    htp_act_preamble3;
+
+    size_t src0_row_size = nb01;
+    size_t src1_row_size = nb11;
+    size_t dst_row_size  = nb1;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+
+    const bool src1_valid = src1->ne[0];
+    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
+    if (!src1_valid) {
+        const int32_t swapped = op_params[1];
+        data_src1             = data_src0;
+        src1_row_size         = src0_row_size;
+
+        const size_t nc_in_bytes = nc * SIZEOF_FP32;
+        data_src0 += swapped ? nc_in_bytes : 0;
+        data_src1 += swapped ? 0 : nc_in_bytes;
+    }
+
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+
+    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
+    uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+    // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
+    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
+    size_t dst_spad_half_size  = dst_spad->size_per_thread / 2;
+
+    const int BLOCK = src0_spad_half_size / src0_row_size_aligned;  // How many rows can we process in one block
+    if (BLOCK == 0) {
+        FARF(ERROR,
+             "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+             src0_spad->size_per_thread, src0_row_size_aligned);
+        return;
+    }
+
+    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+            dst_row_size, dst_row_size_aligned, 0);
+
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+            src0_row_size_aligned, src0_row_size, block_size);
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
+            src1_row_size_aligned, src1_row_size, block_size);
+    }
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        float * dst_spad  = (float *) dma_queue_pop(dma_queue).src;
+        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+        for (uint32_t ib = 0; ib < block_size; ib++) {
+            const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
+            const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
+            uint8_t *       dst_spad_ptr  = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
+
+            // geglu tanh implementation
+            // geglu(x, g) = gelu(x) * g
+            // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc);                       // res = x*x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc);   // res = res * GELU_COEF_A
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc);          // res = res + 1.0f
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc);       // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
+            hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc);         // res = tanh(res)
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc);           // res = res + 1.0f
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc);       // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc);          // res = res + 0.5f
+            hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc);       // res = res * g
+        }
+
+        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
+                                   dst_row_size_aligned, block_size);
+
+        // prefetch N+2 loop iteration if any
+        const uint32_t pref_block = (ir + BLOCK * 2);
+        if (pref_block < src0_end_row) {
+            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+                                       src0_row_size_aligned, src0_row_size, pref_block_size);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
+                                       src1_row_size_aligned, src1_row_size, pref_block_size);
+        }
+    }
+
+    dma_queue_flush(dma_queue);
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
 static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
     struct htp_ops_context * octx = (struct htp_ops_context *) data;
     unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
@@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
                                    &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
 }
 
+static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+                               &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
 static int execute_op_activations_f32(struct htp_ops_context * octx) {
     int err = HTP_STATUS_OK;
 
@@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
             act_op_func = unary_gelu_f32;
             op_type     = "gelu-f32";
             break;
+
+        case HTP_OP_GLU_GEGLU:
+            act_op_func = glu_geglu_f32;
+            op_type     = "geglu-f32";
+            break;
         default:
             FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
             return HTP_STATUS_NO_SUPPORT;
diff --git a/src/ggml-hexagon/htp/argsort-ops.c b/src/ggml-hexagon/htp/argsort-ops.c
new file mode 100644 (file)
index 0000000..a4cee98
--- /dev/null
@@ -0,0 +1,281 @@
+#include <string.h>
+#include <stdlib.h>
+#include <math.h>
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "ggml.h"
+
+#include "hvx-utils.h"
+#include "hex-dma.h"
+
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+struct htp_argsort_context {
+    struct htp_ops_context * octx;
+    uint32_t                 nrows_per_thread;
+};
+
+static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
+{
+    const HVX_Vector one  = Q6_V_vsplat_R(1);
+    const HVX_Vector zero = Q6_V_vzero();
+
+    HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
+    HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
+    HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
+    return hvx_vec_get_i32(sum) == 32;
+}
+
+// Sorts values and mirrors swaps to indices.
+static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
+    if (left >= right) return;
+
+    int pivot_idx = (left + right) / 2;
+    float pivot = values[pivot_idx];
+    int i = left;
+    int j = right;
+
+    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
+    while (i <= j) {
+        // Vectorized scan for i
+        while (i <= j) {
+            // Check if we have at least one full vector
+            if (i + 32 <= j) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+                if (all_greater_f32(pivot_vec, vals_vec)) {
+                    // If all elements are < pivot, we can skip this whole block
+                    i += 32;
+                    continue;
+                }
+            }
+
+            // Scalar fallback / cleanup
+            if (values[i] < pivot) {
+                i++;
+            } else {
+                break;
+            }
+        }
+
+        // Vectorized scan for j
+        while (i <= j) {
+            if (j - 32 >= i) {
+                // Load 32 elements ending at j.
+                // Since we want `values[j] > pivot`, let's load from j-31 to j.
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+                if (all_greater_f32(vals_vec, pivot_vec)) {
+                    j -= 32;
+                    continue;
+                }
+            }
+
+            if (values[j] > pivot) {
+                j--;
+            } else {
+                break;
+            }
+        }
+
+        if (i <= j) {
+            float tmp_val = values[i];
+            values[i] = values[j];
+            values[j] = tmp_val;
+
+            int32_t tmp_idx = indices[i];
+            indices[i] = indices[j];
+            indices[j] = tmp_idx;
+            i++;
+            j--;
+        }
+    }
+
+    if (left < j) quicksort_values_indices_asc(values, indices, left, j);
+    if (i < right) quicksort_values_indices_asc(values, indices, i, right);
+}
+
+static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
+    if (left >= right) return;
+
+    int pivot_idx = (left + right) / 2;
+    float pivot = values[pivot_idx];
+    int i = left;
+    int j = right;
+
+    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
+
+    while (i <= j) {
+        // Vectorized scan for i (values[i] > pivot)
+        while (i <= j) {
+            if (i + 32 <= j) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+                if (all_greater_f32(vals_vec, pivot_vec)) {
+                    i += 32;
+                    continue;
+                }
+            }
+
+            if (values[i] > pivot) {
+                i++;
+            } else {
+                break;
+            }
+        }
+
+        // Vectorized scan for j (values[j] < pivot)
+        while (i <= j) {
+            if (j - 32 >= i) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+                if (all_greater_f32(pivot_vec, vals_vec)) {
+                    j -= 32;
+                    continue;
+                }
+            }
+
+            if (values[j] < pivot) {
+                j--;
+            } else {
+                break;
+            }
+        }
+
+        if (i <= j) {
+            float tmp_val = values[i];
+            values[i] = values[j];
+            values[j] = tmp_val;
+
+            int32_t tmp_idx = indices[i];
+            indices[i] = indices[j];
+            indices[j] = tmp_idx;
+            i++;
+            j--;
+        }
+    }
+
+    if (left < j) quicksort_values_indices_desc(values, indices, left, j);
+    if (i < right) quicksort_values_indices_desc(values, indices, i, right);
+}
+
+static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
+    struct htp_ops_context * octx = actx->octx;
+
+    // Unpack context
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * dst = &octx->dst;
+
+    // Scratchpad memory
+    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
+
+    // Dimensions
+    uint32_t ne00 = src0->ne[0];
+    uint32_t ne01 = src0->ne[1];
+    uint32_t ne02 = src0->ne[2];
+    uint32_t ne03 = src0->ne[3];
+
+    uint32_t nb01 = src0->nb[1];
+    //uint32_t nb02 = src0->nb[2];
+    //uint32_t nb03 = src0->nb[3];
+
+    uint32_t nb1 = dst->nb[1];
+    //uint32_t nb2 = dst->nb[2];
+    //uint32_t nb3 = dst->nb[3];
+
+    // Sort order
+    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
+
+    // Rows to process
+    uint32_t total_rows = ne01 * ne02 * ne03;
+    uint32_t rows_per_thread = actx->nrows_per_thread;
+    uint32_t start_row = rows_per_thread * i;
+    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
+
+    // Scratchpad layout:
+    // We need space for one row of float data (values) and one row of int32 indices.
+    // values: ne00 * sizeof(float)
+    // indices: ne00 * sizeof(int32_t)
+    // Padded to 128 bytes.
+
+    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+    float * values_buf = (float *) spad;
+    int32_t * indices_buf = (int32_t *) (spad + values_size);
+
+    for (uint32_t r = start_row; r < end_row; r++) {
+        uint32_t src_offset = r * nb01;
+        uint32_t dst_offset = r * nb1;
+
+        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
+        uint8_t * dst_ptr = (uint8_t *) dst->data  + dst_offset;
+
+        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
+        hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
+
+        // Initialize indices
+        for (uint32_t j = 0; j < ne00; j++) {
+            indices_buf[j] = j;
+        }
+
+        // Sort values and mirror swaps to indices
+        if (order == GGML_SORT_ORDER_ASC) {
+            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
+        } else {
+            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
+        }
+
+        // Copy indices back to DDR
+        hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
+    }
+}
+
+int op_argsort(struct htp_ops_context * octx) {
+    // Check supported types
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    // Allocate scratchpad
+    // We need 1 row of float + 1 row of int32 per thread.
+    uint32_t ne00 = octx->src0.ne[0];
+    size_t values_size  = hex_round_up(ne00 * sizeof(float), 128);
+    size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
+    size_t spad_per_thread = values_size + indices_size;
+
+    // Make sure we round up to 256 for alignment requirements
+    spad_per_thread = hex_round_up(spad_per_thread, 256);
+
+    size_t total_spad_size = spad_per_thread * octx->n_threads;
+
+    if (octx->ctx->vtcm_size < total_spad_size) {
+        FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src0_spad.size = total_spad_size;
+    octx->src0_spad.size_per_thread = spad_per_thread;
+
+    FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
+         octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
+         octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
+         octx->src0.data, octx->dst.data);
+
+    uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
+    uint32_t n_jobs = MIN(total_rows, octx->n_threads);
+
+    struct htp_argsort_context actx;
+    actx.octx = octx;
+    actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
+
+    // Run jobs
+    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
+
+    return HTP_STATUS_OK;
+}
index de22afe460e3e076913c54c9c4bd3ba2a51c636c..00dbcf87986ac60fefe6eac3da66dba1009dc891 100644 (file)
 #include "htp-msg.h"
 #include "htp-ops.h"
 
-typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems);
-
-static hvx_elemwise_f32_func func_table_HVX[]     = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
-static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa };
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+// Context for binary operations
+struct htp_binary_context {
+    struct htp_ops_context * octx;
+    struct fastdiv_values dim1_div;
+    struct fastdiv_values dim2_div;
+    struct fastdiv_values dim12_div;
+
+    struct fastdiv_values src1_dim1_div; // ne11
+    struct fastdiv_values src1_dim2_div; // ne12
+    struct fastdiv_values src1_dim3_div; // ne13
+
+    uint32_t nrows_per_thread;
+    bool split_at_ne01;
+    bool split_at_ne02;
+
+    // Precomputed values
+    uint32_t block_max;
+    size_t   src0_row_size_aligned;
+    size_t   src1_row_size_aligned;
+    size_t   dst_row_size_aligned;
+    uint32_t src1_fetch_rows; // 1 or block_max
+    uint32_t src1_dma_stride; // 0 or stride
+};
 
 #define htp_binary_preamble            \
     const struct htp_tensor * src0 = &octx->src0; \
     const struct htp_tensor * src1 = &octx->src1; \
-    const struct htp_tensor * src2 = &octx->src2; \
     struct htp_tensor *       dst  = &octx->dst;  \
                                        \
     const uint32_t ne00 = src0->ne[0]; \
@@ -38,266 +60,696 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f3
     const uint32_t ne12 = src1->ne[2]; \
     const uint32_t ne13 = src1->ne[3]; \
                                        \
-    const uint32_t ne0 = dst->ne[0];   \
-    const uint32_t ne1 = dst->ne[1];   \
-    const uint32_t ne2 = dst->ne[2];   \
-    const uint32_t ne3 = dst->ne[3];   \
-                                       \
-    const uint32_t nb00 = src0->nb[0]; \
     const uint32_t nb01 = src0->nb[1]; \
     const uint32_t nb02 = src0->nb[2]; \
     const uint32_t nb03 = src0->nb[3]; \
                                        \
-    const uint32_t nb10 = src1->nb[0]; \
     const uint32_t nb11 = src1->nb[1]; \
     const uint32_t nb12 = src1->nb[2]; \
     const uint32_t nb13 = src1->nb[3]; \
                                        \
-    const uint32_t nb0 = dst->nb[0];   \
     const uint32_t nb1 = dst->nb[1];   \
     const uint32_t nb2 = dst->nb[2];   \
-    const uint32_t nb3 = dst->nb[3];   \
-                                       \
-    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+    const uint32_t nb3 = dst->nb[3];
 
-static void binary_job_f32_per_thread(struct htp_ops_context * octx,
-                                      uint8_t *                spad_data,
-                                      uint32_t                 nth,
-                                      uint32_t                 ith,
-                                      enum htp_op              op) {
-    htp_binary_preamble;
+static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
+                                uint32_t ne01, uint32_t ne02) {
+    uint32_t i03, i02, i01, rem;
+    i03 = fastdiv(ir, &bctx->dim12_div);
+    rem = ir - i03 * (ne02 * ne01);
+    i02 = fastdiv(rem, &bctx->dim1_div);
+    i01 = rem - i02 * ne01;
 
-    const size_t src0_row_size = nb01;
-    const size_t src1_row_size = nb11;
-    const size_t dst_row_size  = nb1;
+    uint32_t rows_left = end_row - ir;
+    uint32_t block_limit = rows_left;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
-    const uint32_t src1_nrows = ne11 * ne12 * ne13;  // src1 rows
+    if (bctx->split_at_ne01) {
+        block_limit = MIN(block_limit, ne01 - i01);
+    }
+    if (bctx->split_at_ne02) {
+         uint32_t rows_in_plane = (ne02 * ne01) - rem;
+         block_limit = MIN(block_limit, rows_in_plane);
+    }
 
-    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
-    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    return MIN(bctx->block_max, block_limit);
+}
 
-    // no work for this thread
-    if (src0_start_row >= src0_end_row) {
-        return;
+// Macro for scalar op switch
+#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
+    switch (octx->op) { \
+        case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
+        case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
+        case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
+        case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
+        default: break; \
     }
 
-    uint64_t t1, t2;
-    t1 = HAP_perf_get_qtimer_count();
+// Macro for vector op switch (All Aligned)
+#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
+    switch (octx->op) { \
+        case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
+        default: break; \
+    }
 
-    int is_aligned = 1;
-    int opt_path   = 0;
-    if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
-        (0 == hex_is_aligned((void *) dst->data, VLEN))) {
-        is_aligned = 0;
+// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
+#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
+    switch (octx->op) { \
+        case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
+        default: break; \
     }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
+
+// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
+#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
+    switch (octx->op) { \
+        case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
+        case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
+        default: break; \
     }
 
-    hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
+// 1. Scalar src1 (ne10 == 1)
+static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
 
-    uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half    = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half     = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    // Preamble
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
 
-    const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
-    uint8_t * restrict dst_ptr        = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
+    // Main loop
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        // src1 indices (broadcast/repeat)
+        uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+        uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+        uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);
+
+        uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+        uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+            float val = *(float *)src1_ptr;
+            src1_ptr += s1_stride;
+            COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
+        }
 
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
 
-    const uint32_t ne02_ne01 = ne02 * ne01;
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
 
-    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
-        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
-        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
-        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
+}
 
-        const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
-        const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
-        const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
+// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast
+static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
 
-        const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint32_t i13 = (ne13 == 1) ? 0 : i03;
+        uint32_t i12 = (ne12 == 1) ? 0 : i02;
+        uint32_t i11 = (ne11 == 1) ? 0 : i01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
 
-        if (ir + 1 < src0_end_row) {
-            hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
-            if (src1_row_size == src0_row_size) {
-                hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1);
-            }
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+        uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
+            uint8_t * r_dst  = d_spad  + r * bctx->dst_row_size_aligned;
+            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
         }
 
-        const uint32_t nr0 = ne00 / ne10;
-        if (nr0 > 1) {
-            if ((1 == is_aligned) && (nr0 == ne00)) {
-                hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0);
-            } else {
-                for (uint32_t r = 0; r < nr0; r++) {
-                    memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
-                }
-            }
-            func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00);
-        } else {
-            func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+
+             uint32_t p13 = (ne13 == 1) ? 0 : p03;
+             uint32_t p12 = (ne12 == 1) ? 0 : p02;
+             uint32_t p11 = (ne11 == 1) ? 0 : p01;
+
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
+
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);
+
+             ir_prefetch += next_block_size;
         }
-
-        src0_ptr += src0_row_size;
-        dst_ptr += dst_row_size;
+        ir += current_block_size;
     }
-
-    t2 = HAP_perf_get_qtimer_count();
-
-    FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
-         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
-         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+    dma_queue_flush(q);
 }
 
-static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
-                                             uint8_t *                spad_data,
-                                             uint32_t                 nth,
-                                             uint32_t                 ith,
-                                             hvx_elemwise_f32_func    func_HVX) {
+// 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1)
+static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
-    const size_t src0_row_size = nb01;
-    const size_t src1_row_size = nb11;
-    const size_t dst_row_size  = nb1;
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
 
-    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
 
-    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
-    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
 
-    // no work for this thread
-    if (src0_start_row >= src0_end_row) {
-        return;
-    }
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
 
-    uint64_t t1, t2;
-    t1 = HAP_perf_get_qtimer_count();
+    void * s1_ptr = (void *) src1_spad;
 
-    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
-    uint8_t * restrict data_dst        = (uint8_t *) dst->data;
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
 
-    const uint32_t ne02_ne01  = ne02 * ne01;
-    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
-        // src0 indices
-        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
-        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
-        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
 
-        // src1 indices
-        const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
-        assert(i11 >= 0 && i11 < ne11);
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
-        float * restrict dst_ptr        = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
-        const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
-        const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
 
-        if (ir + 1 < src0_end_row) {
-            hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
-            if (src1_row_size == src0_row_size) {
-                hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1);
-            }
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
         }
 
-        const uint32_t nr0 = ne00 / ne10;
-        if (nr0 > 1) {
-            for (uint32_t r = 0; r < nr0; r++) {
-                memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
-            }
-            func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00);
-        } else {
-            func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             ir_prefetch += next_block_size;
         }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
+}
+
+// 4. Vector Complex (ne10 == ne00, complex broadcast)
+static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
+
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
     }
 
-    t2 = HAP_perf_get_qtimer_count();
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint32_t r_i01 = i01 + r;
+            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+
+            // Read src1 from DDR (unaligned)
+            COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
+        }
 
-    FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
-         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
-         src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
-         dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
 }
 
-static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
-    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+// 5. Element Repeat (ne10 != ne00)
+static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
 
-    switch (octx->op) {
-        case HTP_OP_MUL:
-        case HTP_OP_ADD:
-        case HTP_OP_SUB:
-            binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
-            break;
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
 
-        case HTP_OP_ADD_ID:
-            binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
-            break;
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint32_t r_i01 = i01 + r;
+            uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+            uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+            uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+
+            // Repeat src1 row
+            for (uint32_t c = 0; c < ne00; c += ne10) {
+                uint32_t len = MIN(ne10, ne00 - c);
+                // Use UUU for speed and simplicity
+                COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
+            }
+        }
 
-        default:
-            FARF(ERROR, "Unknown Binary Op %u", octx->op);
-            break;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
     }
+    dma_queue_flush(q);
 }
 
-static int execute_op_binary_f32(struct htp_ops_context * octx) {
-    int err = HTP_STATUS_OK;
+// 6. ADD_ID (src1 gathered via src2 indices)
+static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
 
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
+    const struct htp_tensor * src2 = &octx->src2;
     struct htp_tensor *       dst  = &octx->dst;
 
-    worker_callback_t binary_op_func;
-    const char *      op_type = NULL;
-
-    switch (octx->op) {
-        case HTP_OP_MUL:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "mul-f32";
-            break;
-
-        case HTP_OP_ADD:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "add-f32";
-            break;
-
-        case HTP_OP_SUB:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "sub-f32";
-            break;
-
-        case HTP_OP_ADD_ID:
-            binary_op_func = binary_job_dispatcher_f32;
-            op_type        = "add-id-f32";
-            break;
-
-        default:
-            FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
-            return HTP_STATUS_NO_SUPPORT;
+    const uint32_t ne00 = src0->ne[0];
+    const uint32_t ne01 = src0->ne[1];
+    const uint32_t ne02 = src0->ne[2];
+    const uint32_t ne03 = src0->ne[3];
+    const uint32_t ne11 = src1->ne[1]; // for bounds check
+
+    const uint32_t nb01 = src0->nb[1];
+    const uint32_t nb02 = src0->nb[2];
+    const uint32_t nb03 = src0->nb[3];
+    const uint32_t nb11 = src1->nb[1]; // src1 row stride
+    const uint32_t nb1 = dst->nb[1];
+    const uint32_t nb2 = dst->nb[2];
+    const uint32_t nb3 = dst->nb[3];
+
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base  = octx->dst_spad.data  + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half  = octx->dst_spad.size_per_thread  / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr  = (uint8_t *)dst->data  + i03 * nb3  + i02 * nb2  + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
+
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
+
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
+
+            const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);
+
+            uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
+
+            hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
+        }
+
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+        if (ir_prefetch < end_row) {
+             uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+             uint32_t p03, p02, p01, prem;
+             p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+             prem = ir_prefetch - p03 * (ne02 * ne01);
+             p02 = fastdiv(prem, &bctx->dim1_div);
+             p01 = prem - p02 * ne01;
+             uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
     }
+    dma_queue_flush(q);
+}
+
+static int execute_op_binary_f32(struct htp_ops_context * octx) {
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * src1 = &octx->src1;
+    struct htp_tensor *       dst  = &octx->dst;
 
-    const int      n_threads  = octx->n_threads;
+    const uint32_t n_threads  = octx->n_threads;
     const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
 
-    const size_t src0_row_size = src0->nb[1];
-    const size_t src1_row_size = src1->nb[1];
-    const size_t dst_row_size  = dst->nb[1];
+    // Use packed row sizes for VTCM allocation
+    const size_t src0_row_size = src0->ne[0] * sizeof(float);
+    const size_t src1_row_size = src1->ne[0] * sizeof(float);
+    const size_t dst_row_size  = dst->ne[0] * sizeof(float);
+
+    // Align to VLEN
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t dst_row_size_aligned  = hex_round_up(dst_row_size, VLEN);
+    size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+
+    bool is_add_id = (octx->op == HTP_OP_ADD_ID);
+    bool is_scalar = !is_add_id && (src1->ne[0] == 1);
+
+    // Determine which kernel we will use to alloc memory and dispatch
+    bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
+               (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
+               (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
+               (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
+
+    bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+    bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
+    bool use_repeat  = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
+
+    size_t spad_row_total;
+    if (is_scalar) {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+    } else if (is_row_bcast) {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+    } else if (use_vector_same) {
+        spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
+    } else if (is_add_id) {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
+    } else {
+        spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+    }
 
-    // VTCM scratchpads for all tensors
-    octx->dst_spad.size  = hex_round_up(dst_row_size, 128) * n_threads;
-    octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
-    octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
+    size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
+    // Adjust for static src1 in row_bcast case
+    if (is_row_bcast) {
+        size_t needed_static = src1_row_size_aligned;
+        if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
+        size_t avail = octx->ctx->vtcm_size - needed_static;
+        rows_per_buffer = avail / (n_threads * spad_row_total);
+    }
 
-    size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+    if (rows_per_buffer < 1) {
+         FARF(ERROR, "binary-f32: VTCM too small\n");
+         return HTP_STATUS_VTCM_TOO_SMALL;
+    }
 
-    FARF(HIGH,
-         "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
-         op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
-         src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
-         octx->dst_spad.size);
+    octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
+    octx->dst_spad.size_per_thread  = rows_per_buffer * 2 * dst_row_size_aligned;
 
-    // Make sure the reserved vtcm size is sufficient
-    if (octx->ctx->vtcm_size < spad_size) {
-        FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
-             octx->ctx->vtcm_size, spad_size);
+    if (is_scalar || use_complex || use_repeat || is_add_id) {
+        octx->src1_spad.size_per_thread = 0;
+    } else if (is_row_bcast) {
+        octx->src1_spad.size_per_thread = 0;
+    } else {
+        octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
+    }
+
+    octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
+    if (is_row_bcast) {
+        octx->src1_spad.size = src1_row_size_aligned;
+    } else {
+        octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
+    }
+    octx->dst_spad.size  = n_threads * octx->dst_spad.size_per_thread;
+
+    if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
         return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
@@ -305,39 +757,71 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
     octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
     octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-    if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs = MIN(n_threads, src0_nrows);
+    if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+        return HTP_STATUS_OK;
+    }
+
+    uint32_t n_jobs = MIN(n_threads, src0_nrows);
 
-        octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+    dma_queue * q = octx->ctx->dma[0];
+    if (is_row_bcast) {
+        dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
+    }
 
-        octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
-        octx->src0_div3  = init_fastdiv_values(src0->ne[3]);
-        octx->src0_div2  = init_fastdiv_values(src0->ne[2]);
-        octx->src0_div1  = init_fastdiv_values(src0->ne[1]);
+    struct htp_binary_context bctx;
+    bctx.octx = octx;
+    bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+    bctx.block_max = rows_per_buffer;
+    bctx.src0_row_size_aligned = src0_row_size_aligned;
+    bctx.src1_row_size_aligned = src1_row_size_aligned;
+    bctx.dst_row_size_aligned  = dst_row_size_aligned;
 
-        octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
-        octx->src1_div3  = init_fastdiv_values(src1->ne[3]);
-        octx->src1_div2  = init_fastdiv_values(src1->ne[2]);
-        octx->src1_div1  = init_fastdiv_values(src1->ne[1]);
+    bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
+    bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
+    bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
 
-        worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
-    }
+    bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
+    bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
+    bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
 
-    return err;
-}
+    bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
+    bool dst_contig_dim1  = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
 
-int op_binary(struct htp_ops_context * octx) {
-    int err = HTP_STATUS_OK;
+    bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
+    bool dst_contig_dim2  = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
 
-    switch (octx->src0.type) {
-        case HTP_TYPE_F32:
-            err = execute_op_binary_f32(octx);
-            break;
+    bctx.split_at_ne01 = (src0->ne[2] > 1) &&
+                         ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
 
-        default:
-            err = HTP_STATUS_NO_SUPPORT;
-            break;
+    bctx.split_at_ne02 = (src0->ne[3] > 1) &&
+                         ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
+
+    // Precompute specific kernel parameters
+    if (use_vector_same) {
+        bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
+        bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
     }
 
-    return err;
+    worker_callback_t worker_func;
+    if (is_add_id) worker_func = binary_job_add_id;
+    else if (is_scalar) worker_func = binary_job_scalar;
+    else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
+    else if (use_vector_same) worker_func = binary_job_vector_same_shape;
+    else if (use_complex) worker_func = binary_job_vector_complex;
+    else worker_func = binary_job_element_repeat;
+
+    if (is_row_bcast) {
+        dma_queue_pop(q);
+    }
+
+    worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);
+
+    return HTP_STATUS_OK;
+}
+
+int op_binary(struct htp_ops_context * octx) {
+    if (octx->src0.type == HTP_TYPE_F32) {
+        return execute_op_binary_f32(octx);
+    }
+    return HTP_STATUS_NO_SUPPORT;
 }
index f49e8ee4478367dab9f3912960713fdc3f96264b..25403bb1126538a6c0103a9de3f671077aa45cdd 100644 (file)
@@ -42,32 +42,36 @@ enum htp_data_type {
     HTP_TYPE_COUNT
 };
 
-// These values are manually translated over to HTP
-// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
+// Do not reorder first 4 (used as an index)
 enum htp_op {
-    HTP_OP_MUL            = 0,
-    HTP_OP_ADD            = 1,
-    HTP_OP_SUB            = 2,
-    HTP_OP_DIV            = 3,
-    HTP_OP_MUL_MAT        = 4,
-    HTP_OP_MUL_MAT_ID     = 5,
-    HTP_OP_RMS_NORM       = 6,
-    HTP_OP_UNARY_SILU     = 7,
-    HTP_OP_UNARY_GELU     = 8,
-    HTP_OP_GLU_SWIGLU     = 9,
-    HTP_OP_GLU_SWIGLU_OAI = 10,
-    HTP_OP_SOFTMAX        = 11,
-    HTP_OP_ADD_ID         = 12,
-    HTP_OP_ROPE           = 13,
-    HTP_OP_FLASH_ATTN_EXT = 14,
-    HTP_OP_SET_ROWS       = 15,
-    HTP_OP_SCALE          = 16,
-    HTP_OP_GET_ROWS       = 17,
-    HTP_OP_CPY            = 18,
+    HTP_OP_MUL = 0,
+    HTP_OP_ADD = 1,
+    HTP_OP_SUB = 2,
+    HTP_OP_DIV = 3,
+    HTP_OP_MUL_MAT,
+    HTP_OP_MUL_MAT_ID,
+    HTP_OP_RMS_NORM,
+    HTP_OP_UNARY_SILU,
+    HTP_OP_UNARY_GELU,
+    HTP_OP_GLU_SWIGLU,
+    HTP_OP_GLU_SWIGLU_OAI,
+    HTP_OP_GLU_GEGLU,
+    HTP_OP_SOFTMAX,
+    HTP_OP_ADD_ID,
+    HTP_OP_ROPE,
+    HTP_OP_FLASH_ATTN_EXT,
+    HTP_OP_SET_ROWS,
+    HTP_OP_GET_ROWS,
+    HTP_OP_SCALE,
+    HTP_OP_CPY,
+    HTP_OP_ARGSORT,
+    HTP_OP_SQR,
+    HTP_OP_SQRT,
+    HTP_OP_SUM_ROWS,
     INVALID
 };
 
-static inline size_t htp_type_block_size(uint32_t t) {
+static inline size_t htp_t_block_size(uint32_t t) {
     switch (t) {
         case HTP_TYPE_F32:
             return 1;
@@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
     return 0;
 }
 
-static const char * htp_type_name(uint32_t t) {
-    switch (t) {
-        case HTP_TYPE_F32:
-            return "fp32";
-        case HTP_TYPE_F16:
-            return "fp16";
-        case HTP_TYPE_Q4_0:
-            return "q4_0";
-        case HTP_TYPE_Q8_0:
-            return "q8_0";
-        case HTP_TYPE_MXFP4:
-            return "mxfp4";
-    }
-    return 0;
-}
-
 // Internal types
 #define QK_Q4_0x4x2  256  // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
 #define QK_Q8_0x4x2  256  // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
index 602a2775a473c74d6947b871defa6d90c9558572..c0d72587ce5d3831261f2ceb64397a40227b571a 100644 (file)
@@ -90,6 +90,7 @@ int op_matmul(struct htp_ops_context * octx);
 int op_matmul_id(struct htp_ops_context * octx);
 int op_binary(struct htp_ops_context * octx);
 int op_unary(struct htp_ops_context * octx);
+int op_sum_rows(struct htp_ops_context * octx);
 int op_activations(struct htp_ops_context * octx);
 int op_softmax(struct htp_ops_context * octx);
 int op_add_id(struct htp_ops_context * octx);
@@ -98,5 +99,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
 int op_set_rows(struct htp_ops_context * octx);
 int op_get_rows(struct htp_ops_context * octx);
 int op_cpy(struct htp_ops_context * octx);
+int op_argsort(struct htp_ops_context * octx);
 
 #endif /* HTP_OPS_H */
index 3449739a4fac068a67a60e244a2002a5c5515571..2577cdd0418c6e5188739fc06563cdcf54b783ce 100644 (file)
 #define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
 #endif
 
-// ADD variants
-
-static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src0 % 128 == 0);
-    assert((unsigned long) src1 % 128 == 0);
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
-}
-
-static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src0 % 128 == 0);
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
-}
-
-static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) src0 % 128 == 0);
-    assert((unsigned long) src1 % 128 == 0);
-    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
-}
-
-static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
-}
-
-// SUB variants
-
-static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src0 % 128 == 0);
-    assert((unsigned long) src1 % 128 == 0);
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
-}
-
-static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src0 % 128 == 0);
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
-}
-
-static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) src0 % 128 == 0);
-    assert((unsigned long) src1 % 128 == 0);
-    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
-}
-
-static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
-}
-
-// MUL variants
-
-static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src0 % 128 == 0);
-    assert((unsigned long) src1 % 128 == 0);
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
-}
-
-static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src0 % 128 == 0);
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
-}
-
-static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((unsigned long) src0 % 128 == 0);
-    assert((unsigned long) src1 % 128 == 0);
-    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
-}
-
-static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
-}
-
-// Dispatchers
-
-static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
-    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
-        if (hex_is_aligned((void *) src1, 128)) {
-            hvx_add_f32_aa(dst, src0, src1, num_elems);
-        } else {
-            hvx_add_f32_au(dst, src0, src1, num_elems);
-        }
-    } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
-        hvx_add_f32_ua(dst, src0, src1, num_elems);
-    } else {
-        hvx_add_f32_uu(dst, src0, src1, num_elems);
-    }
-}
-
-static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
-    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
-        if (hex_is_aligned((void *) src1, 128)) {
-            hvx_sub_f32_aa(dst, src0, src1, num_elems);
-        } else {
-            hvx_sub_f32_au(dst, src0, src1, num_elems);
-        }
-    } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
-        hvx_sub_f32_ua(dst, src0, src1, num_elems);
-    } else {
-        hvx_sub_f32_uu(dst, src0, src1, num_elems);
-    }
-}
-
-static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
-    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
-        if (hex_is_aligned((void *) src1, 128)) {
-            hvx_mul_f32_aa(dst, src0, src1, num_elems);
-        } else {
-            hvx_mul_f32_au(dst, src0, src1, num_elems);
-        }
-    } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
-        hvx_mul_f32_ua(dst, src0, src1, num_elems);
-    } else {
-        hvx_mul_f32_uu(dst, src0, src1, num_elems);
-    }
-}
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
+static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src1 % 128 == 0); \
+    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+} \
+
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
+
+// Dispatcher logic
+#define HVX_BINARY_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128)) { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_aau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_auu(dst, src0, src1, num_elems); \
+        } \
+    } else { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uuu(dst, src0, src1, num_elems); \
+        } \
+    } \
+}
+
+HVX_BINARY_DISPATCHER(hvx_add_f32)
+HVX_BINARY_DISPATCHER(hvx_sub_f32)
+HVX_BINARY_DISPATCHER(hvx_mul_f32)
 
 // Mul-Mul Optimized
-
 static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
     assert((unsigned long) dst % 128 == 0);
     assert((unsigned long) src0 % 128 == 0);
@@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
     }
 }
 
+//
+// Square
+//
+
+#define hvx_sqr_loop_body(dst_type, src_type, vec_store)           \
+    do {                                                                   \
+        dst_type * restrict vdst  = (dst_type *) dst;                      \
+        src_type * restrict vsrc = (src_type *) src;                       \
+                                                                           \
+        const uint32_t elem_size = sizeof(float);                          \
+        const uint32_t epv  = 128 / elem_size;                             \
+        const uint32_t nvec = n / epv;                                     \
+        const uint32_t nloe = n % epv;                                     \
+                                                                           \
+        uint32_t i = 0;                                                    \
+                                                                           \
+        _Pragma("unroll(4)")                                               \
+        for (; i < nvec; i++) {                                            \
+            vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]);                        \
+        }                                                                  \
+        if (nloe) {                                                        \
+            HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]);                   \
+            vec_store((void *) &vdst[i], nloe * elem_size, v);             \
+        }                                                                  \
+    } while(0)
+
+static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
+    if (hex_is_aligned((void *) dst, 128)) {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_aa(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_au(dst, src, num_elems);
+        }
+    } else {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_ua(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_uu(dst, src, num_elems);
+        }
+    }
+}
+
 #undef HVX_OP_ADD
 #undef HVX_OP_SUB
 #undef HVX_OP_MUL
@@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
 #undef hvx_scalar_loop_body
 #undef HVX_OP_MIN_SCALAR
 #undef HVX_OP_CLAMP_SCALAR
+#undef DEFINE_HVX_BINARY_OP_VARIANTS
+#undef HVX_BINARY_DISPATCHER
 
 #endif // HVX_ARITH_H
index ffa6e18e6456a0caf014c6910c1b21dfef9f2afa..12a1b7f12889b4eafa0809fb70cc08a86b228f59 100644 (file)
@@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
     return x;
 }
 
+static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
+    int32_t __attribute__((aligned(128))) x;
+    hvx_vec_store_a(&x, 4, v);
+    return x;
+}
+
 static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
     // abs by clearing the fp16 sign bit
     HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
index 6b617b76177aa4eb56ea74a79ee9d6f12982fbe7..ae0dbed030640e73363403c83a15dea10ba23a2a 100644 (file)
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
         dst_type * restrict vdst = (dst_type *) dst;                                \
         src_type * restrict vsrc = (src_type *) src;                                \
                                                                                     \
-        const HVX_Vector zero = Q6_V_vsplat_R(0);                                   \
-                                                                                    \
         const uint32_t elem_size = sizeof(__fp16);                                  \
         const uint32_t epv  = 128 / elem_size;                                      \
         const uint32_t nvec = n / epv;                                              \
diff --git a/src/ggml-hexagon/htp/hvx-div.h b/src/ggml-hexagon/htp/hvx-div.h
new file mode 100644 (file)
index 0000000..7dae012
--- /dev/null
@@ -0,0 +1,116 @@
+#ifndef HVX_DIV_H
+#define HVX_DIV_H
+
+#include <HAP_farf.h>
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+#include "hex-utils.h"
+#include "hvx-inverse.h"
+#include "hvx-arith.h"
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store)             \
+    do {                                                                             \
+        dst_type * restrict vdst = (dst_type *) dst;                                 \
+        src0_type * restrict vsrc0 = (src0_type *) src0;                             \
+        src1_type * restrict vsrc1 = (src1_type *) src1;                             \
+                                                                                     \
+        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                   \
+                                                                                     \
+        const uint32_t nvec = n / VLEN_FP32;                                         \
+        const uint32_t nloe = n % VLEN_FP32;                                         \
+                                                                                     \
+        uint32_t i = 0;                                                              \
+                                                                                     \
+        _Pragma("unroll(4)")                                                         \
+        for (; i < nvec; i++) {                                                      \
+            HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+            HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1);                         \
+            vdst[i] = res;                                                           \
+        }                                                                            \
+        if (nloe) {                                                                  \
+            HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+            HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1);                         \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res);                   \
+        }                                                                            \
+    } while(0)
+
+// 3-letter suffix variants
+static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    assert((uintptr_t) dst % 128 == 0);
+    assert((uintptr_t) src0 % 128 == 0);
+    assert((uintptr_t) src1 % 128 == 0);
+    hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    assert((uintptr_t) dst % 128 == 0);
+    assert((uintptr_t) src0 % 128 == 0);
+    hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    assert((uintptr_t) dst % 128 == 0);
+    assert((uintptr_t) src1 % 128 == 0);
+    hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    assert((uintptr_t) dst % 128 == 0);
+    hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    assert((uintptr_t) src0 % 128 == 0);
+    assert((uintptr_t) src1 % 128 == 0);
+    hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    assert((uintptr_t) src0 % 128 == 0);
+    hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    assert((uintptr_t) src1 % 128 == 0);
+    hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+    hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
+    if (hex_is_aligned((void *) dst, 128)) {
+        if (hex_is_aligned((void *) src0, 128)) {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
+            else                                    hvx_div_f32_aau(dst, src0, src1, num_elems);
+        } else {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
+            else                                    hvx_div_f32_auu(dst, src0, src1, num_elems);
+        }
+    } else {
+        if (hex_is_aligned((void *) src0, 128)) {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
+            else                                    hvx_div_f32_uau(dst, src0, src1, num_elems);
+        } else {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
+            else                                    hvx_div_f32_uuu(dst, src0, src1, num_elems);
+        }
+    }
+}
+
+#undef HVX_OP_MUL
+
+#endif // HVX_DIV_H
index 1b4aaff0c922cacd56e59048d899512e6afce407..095193277ea9448d1291033c703ed626ad124f26 100644 (file)
@@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
         }                                                       \
     } while(0)
 
+#define hvx_tanh_loop_body(dst_type, src_type, vec_store)       \
+    do {                                                        \
+        dst_type * restrict vdst = (dst_type *) dst;            \
+        src_type * restrict vsrc = (src_type *) src;            \
+                                                                \
+        const uint32_t epv  = 128 / sizeof(float);              \
+        const uint32_t nvec = n / epv;                          \
+        const uint32_t nloe = n % epv;                          \
+                                                                \
+        uint32_t i = 0;                                         \
+                                                                \
+        _Pragma("unroll(4)")                                    \
+        for (; i < nvec; i++) {                                 \
+             vdst[i] = hvx_vec_tanh_f32(vsrc[i]);               \
+        }                                                       \
+        if (nloe) {                                             \
+             HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]);        \
+             vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
+        }                                                       \
+    } while(0)
+
 static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
     assert((unsigned long) dst % 128 == 0);
     assert((unsigned long) src % 128 == 0);
@@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
     hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
 }
 
+static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
 #endif /* HVX_SIGMOID_H */
index 28ee9f68d3e6f05b19848a955f36a758feec0c1f..e31a1006d213e29dce37bcaa2a2be79a56893cb8 100644 (file)
 #define RSQRT_ONE_HALF     0x3f000000  // 0.5
 #define RSQRT_THREE_HALVES 0x3fc00000  // 1.5
 
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
 static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
     //Algorithm :
     //  x2 = input*0.5
     //  y  = * (long *) &input
-    //  y  = 0x5f3759df - (y>>2)
+    //  y  = 0x5f3759df - (y>>1)
     //  y  = y*(threehalfs - x2*y*y)
 
     HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
@@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
     return Q6_Vsf_equals_Vqf32(temp);
 }
 
+// Compute sqrt(x) as x*inv_sqrt(x)
+#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store)                \
+    do {                                                                     \
+        dst_type * restrict vdst = (dst_type *) dst;                         \
+        src_type * restrict vsrc = (src_type *) src;                         \
+                                                                             \
+        const uint32_t nvec = n / VLEN_FP32;                                 \
+        const uint32_t nloe = n % VLEN_FP32;                                 \
+                                                                             \
+        uint32_t i = 0;                                                      \
+                                                                             \
+        _Pragma("unroll(4)")                                                 \
+        for (; i < nvec; i++) {                                              \
+            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \
+            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \
+            vdst[i] = sqrt_res;                                              \
+        }                                                                    \
+        if (nloe) {                                                          \
+            HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]);                \
+            HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]);             \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res);      \
+        }                                                                    \
+    } while(0)
+
+static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
+    if ((unsigned long) dst % 128 == 0) {
+        if ((unsigned long) src % 128 == 0) {
+            hvx_sqrt_f32_aa(dst, src, num_elems);
+        } else {
+            hvx_sqrt_f32_au(dst, src, num_elems);
+        }
+    } else {
+        if ((unsigned long) src % 128 == 0) {
+            hvx_sqrt_f32_ua(dst, src, num_elems);
+        } else {
+            hvx_sqrt_f32_uu(dst, src, num_elems);
+        }
+    }
+}
+
 #endif /* HVX_SQRT_H */
index 7b79a5ea3221a3905c3343c6ccb2dce7c8268b90..a518ad37331dee7040ea60497bbd9f1b7caf40e4 100644 (file)
@@ -12,6 +12,7 @@
 #include "hvx-sigmoid.h"
 #include "hvx-sqrt.h"
 #include "hvx-arith.h"
+#include "hvx-div.h"
 #include "hvx-base.h"
 
 #endif /* HVX_UTILS_H */
index e28a67a95dc1a5f8b6bd5a6ee06ef84c7bca45d6..62708eee5cfab375e745ae3417751eaab0159e18 100644 (file)
@@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context *     ctx,
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_argsort(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
     struct dspqueue_buffer rsp_bufs[1];
 
@@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_sum_rows(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_activations_req(struct htp_context *     ctx,
                                  struct htp_general_req * req,
                                  struct dspqueue_buffer * bufs,
@@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
             case HTP_OP_MUL:
             case HTP_OP_ADD:
             case HTP_OP_SUB:
+            case HTP_OP_DIV:
                 if (n_bufs != 3) {
                     FARF(ERROR, "Bad binary-req buffer list");
                     continue;
@@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_unary_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_SQR:
+            case HTP_OP_SQRT:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad unary-req buffer list");
+                    continue;
+                }
+
+                proc_unary_req(ctx, &req, bufs);
+                break;
+
+            case HTP_OP_SUM_ROWS:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad unary-req buffer list");
+                    continue;
+                }
+
+                proc_sum_rows_req(ctx, &req, bufs);
+                break;
+
             case HTP_OP_UNARY_SILU:
             case HTP_OP_UNARY_GELU:
                 if (n_bufs != 2) {
@@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
             case HTP_OP_GLU_SWIGLU:
             case HTP_OP_GLU_SWIGLU_OAI:
             case HTP_OP_SOFTMAX:
+            case HTP_OP_GLU_GEGLU:
                 if ((n_bufs != 2) && (n_bufs != 3)) {
                     FARF(ERROR, "Bad act-req buffer list");
                     continue;
@@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_cpy_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_ARGSORT:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad argsort-req buffer list");
+                    continue;
+                }
+                proc_argsort_req(ctx, &req, bufs);
+                break;
+
             default:
                 FARF(ERROR, "Unknown Op %u", req.op);
                 break;
diff --git a/src/ggml-hexagon/htp/sum-rows-ops.c b/src/ggml-hexagon/htp/sum-rows-ops.c
new file mode 100644 (file)
index 0000000..62e45da
--- /dev/null
@@ -0,0 +1,115 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <string.h>
+#include <math.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+
+#define sum_rows_preamble                       \
+    struct htp_tensor *src0 =  &octx->src0;\
+    struct htp_tensor *dst  = &octx->dst;  \
+                                           \
+    const uint32_t ne00 = src0->ne[0];     \
+    const uint32_t ne01 = src0->ne[1];     \
+    const uint32_t ne02 = src0->ne[2];     \
+    const uint32_t ne03 = src0->ne[3];     \
+                                           \
+    const uint32_t nb00 = src0->nb[0];     \
+    const uint32_t nb01 = src0->nb[1];     \
+    const uint32_t nb02 = src0->nb[2];     \
+    const uint32_t nb03 = src0->nb[3];     \
+                                           \
+    const uint32_t  ne0 = dst->ne[0];      \
+    const uint32_t  ne1 = dst->ne[1];      \
+    const uint32_t  ne2 = dst->ne[2];      \
+    const uint32_t  ne3 = dst->ne[3];      \
+                                           \
+    const uint32_t  nb0 = dst->nb[0];      \
+    const uint32_t  nb1 = dst->nb[1];      \
+    const uint32_t  nb2 = dst->nb[2];      \
+    const uint32_t  nb3 = dst->nb[3];      \
+
+static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+    sum_rows_preamble;
+
+    const uint32_t src0_nrows_per_thread  = octx->src0_nrows_per_thread;
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size  = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;  // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row   = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return HTP_STATUS_OK;
+    }
+
+    int opt_path   = 0;
+    if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    const uint8_t * restrict data_src = (const uint8_t *) src0->data;
+    uint8_t * restrict data_dst       = (uint8_t *) dst->data;
+
+    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
+    float * restrict dst_th       = (float *) (data_dst + (src0_start_row * dst_row_size));
+
+    for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
+        const float * restrict src_local = src_th + (ir * ne00);
+
+        if (ir + 1 < src0_nrows_per_thread) {
+            hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
+        }
+
+        if (1 == opt_path) {
+            dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
+        } else {
+            dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
+        }
+    }
+
+    return HTP_STATUS_OK;
+}
+
+static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
+    sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_sum_rows(struct htp_ops_context * octx) {
+    sum_rows_preamble;
+
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    const int      n_threads  = octx->n_threads;
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+    uint32_t n_jobs = MIN(n_threads, src0_nrows);
+    octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
+
+    return HTP_STATUS_OK;
+}
+
index 1a27cb6e63e533bb7580233d2807bb5afcdd03f5..ce879bf03701d7897521994bfe4b5947a1e1de0d 100644 (file)
@@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src,
     }
 }
 
+static void sqr_htp_f32(const float * restrict src,
+                          float * restrict dst,
+                          uint8_t * restrict spad,
+                          const uint32_t num_rows,
+                          const uint32_t row_elems,
+                          const size_t   row_size,
+                          int32_t *      op_params,
+                          int            opt_path) {
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const float * restrict src_local = src + (ir * row_elems);
+        float * restrict dst_local       = dst + (ir * row_elems);
+
+        if (ir + 1 < num_rows) {
+            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+        }
+
+        if (1 == opt_path) {
+            hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+        } else {
+            hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+        }
+    }
+}
+
+static void sqrt_htp_f32(const float * restrict src,
+                          float * restrict dst,
+                          uint8_t * restrict spad,
+                          const uint32_t num_rows,
+                          const uint32_t row_elems,
+                          const size_t   row_size,
+                          int32_t *      op_params,
+                          int            opt_path) {
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const float * restrict src_local = src + (ir * row_elems);
+        float * restrict dst_local       = dst + (ir * row_elems);
+
+        if (ir + 1 < num_rows) {
+            hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+        }
+
+        if (1 == opt_path) {
+            hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+        } else {
+            hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+        }
+    }
+}
+
 static void unary_job_f32_per_thread(const struct htp_tensor * src,
                                      struct htp_tensor *       dst,
                                      uint8_t *                 spad,
@@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
         case HTP_OP_SCALE:
             scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
             break;
+        case HTP_OP_SQR:
+            sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+            break;
+        case HTP_OP_SQRT:
+            sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+            break;
 
         default:
             break;
@@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
             unary_op_func = unary_job_dispatcher_f32;
             op_type       = "scale-f32";
             break;
+        case HTP_OP_SQR:
+            unary_op_func = unary_job_dispatcher_f32;
+            op_type       = "sqr-f32";
+            break;
+        case HTP_OP_SQRT:
+            unary_op_func = unary_job_dispatcher_f32;
+            op_type       = "sqrt-f32";
+            break;
 
         default:
             FARF(ERROR, "Unsupported unary Op %u\n", octx->op);