]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
hexagon: add fp16 support for binary ops: add,sub,mul,div (llama/20139)
authorYardenTal44 <redacted>
Fri, 6 Mar 2026 02:29:13 +0000 (04:29 +0200)
committerGeorgi Gerganov <redacted>
Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
* hexagon: add fp16 support for binary ops: add,sub,mul,div

* hexagon: fix test-backend-ops failures for fp16 binary ops on older arches (<v79)

* hexagon: decide on n_threads (aka n_jobs) early to avoid overallocating scratchpad

* snapdragon: fix readme link

---------

Co-authored-by: Max Krasnyansky <redacted>
15 files changed:
src/ggml-hexagon/ggml-hexagon.cpp
src/ggml-hexagon/htp/act-ops.c
src/ggml-hexagon/htp/argsort-ops.c
src/ggml-hexagon/htp/binary-ops.c
src/ggml-hexagon/htp/cpy-ops.c
src/ggml-hexagon/htp/get-rows-ops.c
src/ggml-hexagon/htp/hvx-arith.h
src/ggml-hexagon/htp/hvx-base.h
src/ggml-hexagon/htp/hvx-div.h
src/ggml-hexagon/htp/hvx-inverse.h
src/ggml-hexagon/htp/rope-ops.c
src/ggml-hexagon/htp/set-rows-ops.c
src/ggml-hexagon/htp/softmax-ops.c
src/ggml-hexagon/htp/sum-rows-ops.c
src/ggml-hexagon/htp/unary-ops.c

index 3006e217796ff01e2627d47bf1eb7473880d5fb5..b70da8f3b28a05cc986a3f19578f45e798afbe88 100644 (file)
@@ -1865,15 +1865,26 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
     const struct ggml_tensor * src1 = op->src[1];
     const struct ggml_tensor * dst  = op;
 
-    if (src0->type != GGML_TYPE_F32) {
-        return false;
+    if (src0->type == GGML_TYPE_F32) {
+        if (src1->type != GGML_TYPE_F32) {
+            return false;
+        }
+        if (dst->type != GGML_TYPE_F32) {
+            return false;
+        }
     }
-    if (src1->type != GGML_TYPE_F32) {
-        return false;
+    else if (src0->type == GGML_TYPE_F16) {
+        if (src1->type != GGML_TYPE_F16) {
+            return false;
+        }
+        if (dst->type != GGML_TYPE_F16) {
+            return false;
+        }
     }
-    if (dst->type != GGML_TYPE_F32) {
+    else {
         return false;
     }
+
     if (!ggml_are_same_shape(src0, dst)) {
         return false;
     }
index 21bd4050a1d2706966085554db42aa5d50411c22..d8b924981e0999d650aa03638f9d45b37e11c706 100644 (file)
@@ -693,8 +693,8 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const uint32_t n_threads  = octx->n_threads;
     const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
     size_t src0_row_size = src0->nb[1];
     size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
@@ -748,13 +748,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
         return HTP_STATUS_OK;
     }
 
-    uint32_t n_jobs = MIN(n_threads, src0_nrows);
-
     // Prepare context
     struct htp_act_context actx;
     actx.octx = octx;
 
-    actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+    actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
 
     actx.src0_row_size = src0_row_size;
     actx.src1_row_size = src1_row_size;
@@ -794,7 +792,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
     actx.data_src1 = data_src1;
     actx.data_dst  = (uint8_t *) dst->data;
 
-    worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads);
     return HTP_STATUS_OK;
 }
 
index a4cee980be8baecfec6935773c5b2c9a33d7c83b..170220e8f80c7132dde52176060fb41aada97bbc 100644 (file)
@@ -241,6 +241,9 @@ int op_argsort(struct htp_ops_context * octx) {
         return HTP_STATUS_NO_SUPPORT;
     }
 
+    const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
+    const uint32_t n_threads = MIN(total_rows, octx->n_threads);
+
     // Allocate scratchpad
     // We need 1 row of float + 1 row of int32 per thread.
     uint32_t ne00 = octx->src0.ne[0];
@@ -251,7 +254,7 @@ int op_argsort(struct htp_ops_context * octx) {
     // Make sure we round up to 256 for alignment requirements
     spad_per_thread = hex_round_up(spad_per_thread, 256);
 
-    size_t total_spad_size = spad_per_thread * octx->n_threads;
+    size_t total_spad_size = spad_per_thread * n_threads;
 
     if (octx->ctx->vtcm_size < total_spad_size) {
         FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
@@ -267,15 +270,12 @@ int op_argsort(struct htp_ops_context * octx) {
          octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
          octx->src0.data, octx->dst.data);
 
-    uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
-    uint32_t n_jobs = MIN(total_rows, octx->n_threads);
-
     struct htp_argsort_context actx;
     actx.octx = octx;
-    actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
+    actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads;
 
     // Run jobs
-    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads);
 
     return HTP_STATUS_OK;
 }
index 00dbcf87986ac60fefe6eac3da66dba1009dc891..ec90f22de52ee610503bf7ef41dba18984badeb6 100644 (file)
@@ -95,43 +95,87 @@ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_
 }
 
 // Macro for scalar op switch
-#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
-    switch (octx->op) { \
-        case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
-        case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
-        case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
-        case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
-        default: break; \
+#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+            case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+            case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+            case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+            default: break; \
+        } \
     }
 
 // Macro for vector op switch (All Aligned)
-#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
-    switch (octx->op) { \
-        case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
-        default: break; \
+#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
     }
 
 // Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
-#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
-    switch (octx->op) { \
-        case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
-        default: break; \
+#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
     }
 
 // Macro for vector op switch (All Unaligned - generic loop used in element repeat)
-#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
-    switch (octx->op) { \
-        case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
-        case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
-        default: break; \
+#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \
+    if(TYPE == HTP_TYPE_F32) { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
+    } \
+    else { \
+        switch (octx->op) { \
+            case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \
+            case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \
+            default: break; \
+        } \
     }
 
 // 1. Scalar src1 (ne10 == 1)
@@ -140,6 +184,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
     struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
     const uint32_t total_rows = ne01 * ne02 * ne03;
     const uint32_t start_row = bctx->nrows_per_thread * ith;
     const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
@@ -170,7 +216,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
         dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
-        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
     }
@@ -199,13 +245,12 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
         for (uint32_t r = 0; r < current_block_size; r++) {
             uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
             uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
-            float val = *(float *)src1_ptr;
+            COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);
             src1_ptr += s1_stride;
-            COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
         }
 
         uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
-        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@@ -216,7 +261,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
              p01 = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
 
-             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
              ir_prefetch += next_block_size;
         }
         ir += current_block_size;
@@ -230,6 +275,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
     struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
     const uint32_t total_rows = ne01 * ne02 * ne03;
     const uint32_t start_row = bctx->nrows_per_thread * ith;
     const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
@@ -268,8 +315,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
         dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
-        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
-        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
+        dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
     }
@@ -284,7 +331,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
             uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
             uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
             uint8_t * r_dst  = d_spad  + r * bctx->dst_row_size_aligned;
-            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
+            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
         }
 
         uint32_t i03, i02, i01, rem;
@@ -293,7 +340,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
         i02 = fastdiv(rem, &bctx->dim1_div);
         i01 = rem - i02 * ne01;
         uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
-        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@@ -310,8 +357,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
              uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
 
-             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
-             dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
+             dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
 
              ir_prefetch += next_block_size;
         }
@@ -326,6 +373,8 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
     struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
     const uint32_t total_rows = ne01 * ne02 * ne03;
     const uint32_t start_row = bctx->nrows_per_thread * ith;
     const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
@@ -359,7 +408,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
         dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
-        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
     }
@@ -373,7 +422,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
             uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
             uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
             uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
-            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
+            COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
         }
 
         uint32_t i03, i02, i01, rem;
@@ -382,7 +431,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
         i02 = fastdiv(rem, &bctx->dim1_div);
         i01 = rem - i02 * ne01;
         uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
-        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@@ -392,7 +441,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith,
              p02 = fastdiv(prem, &bctx->dim1_div);
              p01 = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
-             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
              ir_prefetch += next_block_size;
         }
         ir += current_block_size;
@@ -406,6 +455,8 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
     struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
     const uint32_t total_rows = ne01 * ne02 * ne03;
     const uint32_t start_row = bctx->nrows_per_thread * ith;
     const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
@@ -435,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
         dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
-        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
     }
@@ -462,11 +513,11 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
             uint8_t * r_dst  = d_spad + r * bctx->dst_row_size_aligned;
 
             // Read src1 from DDR (unaligned)
-            COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
+            COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);
         }
 
         uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
-        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@@ -476,7 +527,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void *
              p02 = fastdiv(prem, &bctx->dim1_div);
              p01 = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
-             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
              ir_prefetch += next_block_size;
         }
         ir += current_block_size;
@@ -490,6 +541,9 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
     struct htp_ops_context * octx = bctx->octx;
     htp_binary_preamble;
 
+    const uint32_t src0_type = octx->src0.type;
+    const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
+    const uint32_t row_size_bytes = ne00 * elem_size_bytes;;
     const uint32_t total_rows = ne01 * ne02 * ne03;
     const uint32_t start_row = bctx->nrows_per_thread * ith;
     const uint32_t end_row   = MIN(start_row + bctx->nrows_per_thread, total_rows);
@@ -519,7 +573,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
         uint8_t * d_spad  = dst_spad_base  + spad_idx * dst_spad_half;
 
         dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
-        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
         ir_prefetch += current_block_size;
         spad_idx ^= 1;
     }
@@ -549,12 +603,12 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
             for (uint32_t c = 0; c < ne00; c += ne10) {
                 uint32_t len = MIN(ne10, ne00 - c);
                 // Use UUU for speed and simplicity
-                COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
+                COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);
             }
         }
 
         uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
-        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
 
         if (ir_prefetch < end_row) {
              uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
@@ -564,7 +618,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void *
              p02 = fastdiv(prem, &bctx->dim1_div);
              p01 = prem - p02 * ne01;
              uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
-             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+             dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
              ir_prefetch += next_block_size;
         }
         ir += current_block_size;
@@ -672,18 +726,20 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
     dma_queue_flush(q);
 }
 
-static int execute_op_binary_f32(struct htp_ops_context * octx) {
+static int execute_op_binary(struct htp_ops_context * octx) {
     const struct htp_tensor * src0 = &octx->src0;
     const struct htp_tensor * src1 = &octx->src1;
     struct htp_tensor *       dst  = &octx->dst;
 
-    const uint32_t n_threads  = octx->n_threads;
     const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
     // Use packed row sizes for VTCM allocation
-    const size_t src0_row_size = src0->ne[0] * sizeof(float);
-    const size_t src1_row_size = src1->ne[0] * sizeof(float);
-    const size_t dst_row_size  = dst->ne[0] * sizeof(float);
+    const uint32_t src0_type = octx->src0.type;
+    const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
+    const size_t src0_row_size = src0->ne[0] * elem_size;
+    const size_t src1_row_size = src1->ne[0] * elem_size;
+    const size_t dst_row_size  = dst->ne[0] * elem_size;
 
     // Align to VLEN
     const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
@@ -694,7 +750,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
     bool is_scalar = !is_add_id && (src1->ne[0] == 1);
 
     // Determine which kernel we will use to alloc memory and dispatch
-    bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
+    bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
                (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
                (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
                (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
@@ -726,7 +782,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
     }
 
     if (rows_per_buffer < 1) {
-         FARF(ERROR, "binary-f32: VTCM too small\n");
+         FARF(ERROR, "binary: VTCM too small\n");
          return HTP_STATUS_VTCM_TOO_SMALL;
     }
 
@@ -761,16 +817,14 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
         return HTP_STATUS_OK;
     }
 
-    uint32_t n_jobs = MIN(n_threads, src0_nrows);
-
     dma_queue * q = octx->ctx->dma[0];
     if (is_row_bcast) {
-        dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
+        dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1);
     }
 
     struct htp_binary_context bctx;
     bctx.octx = octx;
-    bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+    bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
     bctx.block_max = rows_per_buffer;
     bctx.src0_row_size_aligned = src0_row_size_aligned;
     bctx.src1_row_size_aligned = src1_row_size_aligned;
@@ -814,14 +868,24 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) {
         dma_queue_pop(q);
     }
 
-    worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);
 
     return HTP_STATUS_OK;
 }
 
 int op_binary(struct htp_ops_context * octx) {
-    if (octx->src0.type == HTP_TYPE_F32) {
-        return execute_op_binary_f32(octx);
+
+    // Does not support permutations of src1
+    const struct htp_tensor * src1 = &octx->src1;
+    if (src1->nb[1] < src1->nb[0]) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const uint32_t src0_type = octx->src0.type;
+    if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
+        return execute_op_binary(octx);
     }
+
     return HTP_STATUS_NO_SUPPORT;
 }
+
index 559ca18378993b7f2891ec0cd0855903791fcecf..a40d866b9c3b891b44bac1c9593b254632d96325 100644 (file)
@@ -202,6 +202,8 @@ static void cpy_work_func(unsigned int n, unsigned int i, void *data) {
 int op_cpy(struct htp_ops_context * octx) {
     cpy_preamble;
 
+    const uint32_t n_threads = MIN(nr, octx->n_threads);
+
     struct htp_copy_context ct;
     ct.octx = octx;
 
@@ -227,8 +229,7 @@ int op_cpy(struct htp_ops_context * octx) {
     const bool transposed = (nb00 > nb01) || (nb0 > nb1);
     const bool sameshape  = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
 
-    const uint32_t n_jobs = MIN(nr, octx->n_threads);
-    ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+    ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
 
     if (sametype && sameshape) {
         ct.copy = cpy_thread_sametype_sameshape;
@@ -245,7 +246,7 @@ int op_cpy(struct htp_ops_context * octx) {
         return HTP_STATUS_NO_SUPPORT;
     }
 
-    worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads);
 
     return HTP_STATUS_OK;
 }
index bf24bbda70ae2dfb14b8203b84fa1427d5909ba2..047d2850aaa9900afae2c8a9d870d9702ed10a28 100644 (file)
@@ -82,6 +82,8 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da
 int op_get_rows(struct htp_ops_context * octx) {
     get_rows_preamble;
 
+    const uint32_t n_threads = MIN(nr, octx->n_threads);
+
     if (octx->src0.type != HTP_TYPE_F32) {
         return HTP_STATUS_NO_SUPPORT;
     }
@@ -103,9 +105,8 @@ int op_get_rows(struct htp_ops_context * octx) {
     grctx.get_rows_div_ne10      = init_fastdiv_values(octx->src1.ne[0]);
     grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
 
-    const uint32_t n_jobs = MIN(nr, octx->n_threads);
-    grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+    grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
 
-    worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads);
     return HTP_STATUS_OK;
 }
index 2577cdd0418c6e5188739fc06563cdcf54b783ce..82e3416970b4f33754645b97968b5f9e284896d0 100644 (file)
 // Binary operations (add, mul, sub)
 //
 
-#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
+#define UNUSED(x) (void)(x)
+
+#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \
     do {                                                                       \
         dst_type * restrict vdst  = (dst_type *) dst;                          \
         src0_type * restrict vsrc0 = (src0_type *) src0;                       \
         src1_type * restrict vsrc1 = (src1_type *) src1;                       \
                                                                                \
-        const uint32_t elem_size = sizeof(float);                              \
-        const uint32_t epv  = 128 / elem_size;                                 \
+        const uint32_t epv  = 128 / (elem_size);                               \
         const uint32_t nvec = n / epv;                                         \
         const uint32_t nloe = n % epv;                                         \
                                                                                \
         }                                                                      \
         if (nloe) {                                                            \
             HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]);                         \
-            vec_store((void *) &vdst[i], nloe * elem_size, v);                 \
+            vec_store((void *) &vdst[i], nloe * (elem_size), v);               \
         }                                                                      \
     } while(0)
 
 #if __HVX_ARCH__ < 79
-#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
-#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
-#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+
+#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
+#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+
 #else
-#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
-#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
-#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+
+#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
+#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+
 #endif
 
+#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b)
+#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b)
+#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b)
+
 // Generic macro to define alignment permutations for an op
-#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
+#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \
 static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
     assert((uintptr_t) dst % 128 == 0); \
     assert((uintptr_t) src0 % 128 == 0); \
     assert((uintptr_t) src1 % 128 == 0); \
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
 } \
 static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
     assert((uintptr_t) dst % 128 == 0); \
     assert((uintptr_t) src0 % 128 == 0); \
-    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+    hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
 } \
 static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
     assert((uintptr_t) dst % 128 == 0); \
     assert((uintptr_t) src1 % 128 == 0); \
-    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
 } \
 static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
     assert((uintptr_t) dst % 128 == 0); \
-    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+    hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
 } \
 static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
     assert((uintptr_t) src0 % 128 == 0); \
     assert((uintptr_t) src1 % 128 == 0); \
-    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
 } \
 static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
     assert((uintptr_t) src0 % 128 == 0); \
-    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+    hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
 } \
 static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
     assert((uintptr_t) src1 % 128 == 0); \
-    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
 } \
 static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
-    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+    hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
 } \
 
-DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
-DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
-DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float)
+
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16)
 
 // Dispatcher logic
 #define HVX_BINARY_DISPATCHER(OP_NAME) \
@@ -115,6 +128,10 @@ HVX_BINARY_DISPATCHER(hvx_add_f32)
 HVX_BINARY_DISPATCHER(hvx_sub_f32)
 HVX_BINARY_DISPATCHER(hvx_mul_f32)
 
+HVX_BINARY_DISPATCHER(hvx_add_f16)
+HVX_BINARY_DISPATCHER(hvx_sub_f16)
+HVX_BINARY_DISPATCHER(hvx_mul_f16)
+
 // Mul-Mul Optimized
 static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
     assert((unsigned long) dst % 128 == 0);
@@ -136,26 +153,25 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re
 
     _Pragma("unroll(4)")
     for (; i < nvec; i++) {
-        HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
+        HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
         vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
     }
 
     if (nloe) {
-        HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
-        HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
+        HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
+        HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]);
         hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
     }
 }
 
 // Scalar Operations
 
-#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro)   \
+#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro)   \
     do {                                                                       \
         dst_type * restrict vdst = (dst_type *) dst;                           \
         src_type * restrict vsrc = (src_type *) src;                           \
                                                                                \
-        const uint32_t elem_size = sizeof(float);                              \
-        const uint32_t epv  = 128 / elem_size;                                 \
+        const uint32_t epv  = 128 / (elem_size);                               \
         const uint32_t nvec = n / epv;                                         \
         const uint32_t nloe = n % epv;                                         \
                                                                                \
@@ -169,138 +185,88 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re
         if (nloe) {                                                            \
             HVX_Vector v = vsrc[i];                                            \
             v = scalar_op_macro(v);                                            \
-            vec_store((void *) &vdst[i], nloe * elem_size, v);                 \
+            vec_store((void *) &vdst[i], nloe * (elem_size), v);               \
         }                                                                      \
     } while(0)
 
-#define HVX_OP_ADD_SCALAR(v) \
+#define HVX_OP_ADD_SCALAR_F32(v) \
     ({ \
         const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
-        HVX_Vector out = HVX_OP_ADD(v, val_vec); \
+        HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \
         Q6_V_vmux_QVV(pred_inf, inf, out); \
     })
 
-#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
-#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
-
-// Add Scalar Variants
-
-static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
-}
-
-static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
-    assert((unsigned long) dst % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
-}
-
-static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
-    assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
-}
-
-static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    static const float kInf = INFINITY;
-    const HVX_Vector inf = hvx_vec_splat_f32(kInf);
-    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
-}
-
-// Sub Scalar Variants
+#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec)
+#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec)
 
-static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
-}
-
-static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    assert((unsigned long) dst % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
-}
-
-static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
-}
-
-static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
-}
+#define HVX_OP_ADD_SCALAR_F16(v) \
+    ({ \
+        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \
+        HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \
+        Q6_V_vmux_QVV(pred_inf, inf, out); \
+    })
 
-// Mul Scalar Variants
+#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec)
+#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec)
 
-static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
-}
+// Scalar Variants
 
-static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    assert((unsigned long) dst % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
-}
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \
+static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src % 128 == 0); \
+    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    assert((uintptr_t) dst % 128 == 0); \
+    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    assert((uintptr_t) src % 128 == 0); \
+    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+    const HVX_Vector val_vec = SPLAT_MACRO(val); \
+    const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
 
-static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
-}
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float)
 
-static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
-    const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
-}
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16)
 
-static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
-    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
-        hvx_add_scalar_f32_aa(dst, src, val, num_elems);
-    } else if (hex_is_aligned((void *) dst, 128)) {
-        hvx_add_scalar_f32_au(dst, src, val, num_elems);
-    } else if (hex_is_aligned((void *) src, 128)) {
-        hvx_add_scalar_f32_ua(dst, src, val, num_elems);
-    } else {
-        hvx_add_scalar_f32_uu(dst, src, val, num_elems);
-    }
+// Dispatcher logic
+#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_aa(dst, src, val, num_elems); \
+    } else if (hex_is_aligned((void *) dst, 128)) { \
+        OP_NAME##_au(dst, src, val, num_elems); \
+    } else if (hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_ua(dst, src, val, num_elems); \
+    } else { \
+        OP_NAME##_uu(dst, src, val, num_elems); \
+    } \
 }
 
-static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
-    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
-        hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
-    } else if (hex_is_aligned((void *) dst, 128)) {
-        hvx_mul_scalar_f32_au(dst, src, val, num_elems);
-    } else if (hex_is_aligned((void *) src, 128)) {
-        hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
-    } else {
-        hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
-    }
-}
+HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float)
 
-static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
-    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
-        hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
-    } else if (hex_is_aligned((void *) dst, 128)) {
-        hvx_sub_scalar_f32_au(dst, src, val, num_elems);
-    } else if (hex_is_aligned((void *) src, 128)) {
-        hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
-    } else {
-        hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
-    }
-}
+HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16)
 
 // MIN Scalar variants
 
@@ -310,24 +276,24 @@ static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t *
     const HVX_Vector val_vec = hvx_vec_splat_f32(val);
     assert((unsigned long) dst % 128 == 0);
     assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
 }
 
 static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
     const HVX_Vector val_vec = hvx_vec_splat_f32(val);
     assert((unsigned long) dst % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
 }
 
 static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
     const HVX_Vector val_vec = hvx_vec_splat_f32(val);
     assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
 }
 
 static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
     const HVX_Vector val_vec = hvx_vec_splat_f32(val);
-    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
 }
 
 static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
@@ -357,27 +323,27 @@ static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t
     const HVX_Vector max_vec = hvx_vec_splat_f32(max);
     assert((unsigned long) dst % 128 == 0);
     assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+    hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
 }
 
 static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
     const HVX_Vector min_vec = hvx_vec_splat_f32(min);
     const HVX_Vector max_vec = hvx_vec_splat_f32(max);
     assert((unsigned long) dst % 128 == 0);
-    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+    hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
 }
 
 static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
     const HVX_Vector min_vec = hvx_vec_splat_f32(min);
     const HVX_Vector max_vec = hvx_vec_splat_f32(max);
     assert((unsigned long) src % 128 == 0);
-    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+    hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
 }
 
 static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
     const HVX_Vector min_vec = hvx_vec_splat_f32(min);
     const HVX_Vector max_vec = hvx_vec_splat_f32(max);
-    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+    hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
 }
 
 static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
@@ -396,7 +362,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
 // Square
 //
 
-#define hvx_sqr_loop_body(dst_type, src_type, vec_store)           \
+#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store)           \
     do {                                                                   \
         dst_type * restrict vdst  = (dst_type *) dst;                      \
         src_type * restrict vsrc = (src_type *) src;                       \
@@ -410,10 +376,10 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
                                                                            \
         _Pragma("unroll(4)")                                               \
         for (; i < nvec; i++) {                                            \
-            vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]);                        \
+            vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]);                        \
         }                                                                  \
         if (nloe) {                                                        \
-            HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]);                   \
+            HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]);                   \
             vec_store((void *) &vdst[i], nloe * elem_size, v);             \
         }                                                                  \
     } while(0)
@@ -421,21 +387,21 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
 static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
     assert((unsigned long) dst % 128 == 0);
     assert((unsigned long) src % 128 == 0);
-    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+    hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
 }
 
 static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
     assert((unsigned long) dst % 128 == 0);
-    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+    hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
 }
 
 static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
     assert((unsigned long) src % 128 == 0);
-    hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+    hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
 }
 
 static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+    hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
 }
 
 static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
@@ -454,17 +420,24 @@ static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict
     }
 }
 
-#undef HVX_OP_ADD
-#undef HVX_OP_SUB
-#undef HVX_OP_MUL
+#undef HVX_OP_ADD_F32
+#undef HVX_OP_SUB_F32
+#undef HVX_OP_MUL_F32
+#undef HVX_OP_ADD_F16
+#undef HVX_OP_SUB_F16
+#undef HVX_OP_MUL_F16
 #undef hvx_arith_loop_body
-#undef HVX_OP_ADD_SCALAR
-#undef HVX_OP_SUB_SCALAR
-#undef HVX_OP_MUL_SCALAR
+#undef HVX_OP_ADD_SCALAR_F32
+#undef HVX_OP_SUB_SCALAR_F32
+#undef HVX_OP_MUL_SCALAR_F32
+#undef HVX_OP_ADD_SCALAR_F16
+#undef HVX_OP_SUB_SCALAR_F16
+#undef HVX_OP_MUL_SCALAR_F16
 #undef hvx_scalar_loop_body
 #undef HVX_OP_MIN_SCALAR
 #undef HVX_OP_CLAMP_SCALAR
 #undef DEFINE_HVX_BINARY_OP_VARIANTS
 #undef HVX_BINARY_DISPATCHER
+#undef UNUSED
 
 #endif // HVX_ARITH_H
index 701637f22b2ea15537ae43dd8e210e219ae655a3..578ca288fb65f288c53de83fab24f70957de7f35 100644 (file)
@@ -189,4 +189,52 @@ static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vect
 
 #endif
 
+#if __HVX_ARCH__ < 79
+
+static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
+    const HVX_Vector one    = Q6_Vh_vsplat_R(0x3C00); //  1.0 in IEEE FP16
+    HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
+    HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
+    HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
+    HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
+    return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
+}
+
+static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
+    const HVX_Vector one    = Q6_Vh_vsplat_R(0x3C00); //  1.0 in IEEE FP16
+    HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
+    HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
+    HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
+    HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
+    return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
+}
+
+static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
+}
+
+#else
+
+static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_vadd_VhfVhf(a, b);
+}
+
+static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_vsub_VhfVhf(a, b);
+}
+
+static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+    return Q6_Vhf_vmpy_VhfVhf(a, b);
+}
+
+#endif // __HVX_ARCH__ < 79
+
 #endif /* HVX_BASE_H */
index 7dae012e0ed07bdd2c06f831863af9b18b315c58..05cefea039f6fd91e05265c2b77525ed4a43732b 100644 (file)
 #include "hvx-arith.h"
 
 #if __HVX_ARCH__ < 79
-#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
 #else
-#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
 #endif
 
+// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32.
+static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+    HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
+    HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32));
+    HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32));
+#else
+    HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
+    HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32);
+    HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32);
+#endif
+
+    HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const);
+    HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const);
+
+#if __HVX_ARCH__ < 79
+    HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
+#else
+    HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
+#endif
+    return res;
+}
+
+#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store)                     \
+    do {                                                                                \
+        dst_type * restrict vdst = (dst_type *) dst;                                    \
+        src_type * restrict vsrc = (src_type *) src;                                    \
+        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                     \
+                                                                                        \
+        const uint32_t nvec = n / VLEN_FP16;                                            \
+        const uint32_t nloe = n % VLEN_FP16;                                            \
+                                                                                        \
+        uint32_t i = 0;                                                                 \
+                                                                                        \
+        _Pragma("unroll(4)")                                                            \
+        for (; i < nvec; i++) {                                                         \
+            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+            vdst[i] = res;                                                              \
+        }                                                                               \
+        if (nloe) {                                                                     \
+            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                      \
+        }                                                                               \
+    } while(0)
+
+static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    assert((uintptr_t) dst % 128 == 0);
+    assert((uintptr_t) src % 128 == 0);
+    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    assert((uintptr_t) dst % 128 == 0);
+    hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    assert((uintptr_t) src % 128 == 0);
+    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+    const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// Compute div by using hvx_vec_inverse_f32_guard. Requires first by exapnding fp32 to fp16 and convert the result back to fp32.
+static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+    // Convert first input to fp32
+    HVX_VectorPair vec1_to_f32   = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32));
+    HVX_Vector     vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32));
+
+    // Convert second input to fp32
+    HVX_VectorPair vec2_to_f32   = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32));
+    HVX_Vector     vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32));
+#else
+    // Convert first input to fp32
+    HVX_VectorPair vec1_to_f32   = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32);
+    HVX_Vector     vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32);
+
+    // Convert second input to fp32
+    HVX_VectorPair vec2_to_f32   = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0);  // *1.0
+    HVX_Vector     vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32);
+    HVX_Vector     vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32);
+#endif
+
+    // Inverse second input in fp32
+    HVX_Vector     vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask);
+    HVX_Vector     vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask);
+
+    // Multiply first input by inverse of second, in fp32
+    HVX_Vector     div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0);
+    HVX_Vector     div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1);
+
+    // Convert back to fp16
+#if __HVX_ARCH__ < 79
+    HVX_Vector     recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
+#else
+    HVX_Vector     recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
+#endif
+
+    return recip;
+}
+
+#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store)                  \
+    do {                                                                                  \
+        dst_type * restrict vdst = (dst_type *) dst;                                      \
+        src0_type * restrict vsrc0 = (src0_type *) src0;                                  \
+        src1_type * restrict vsrc1 = (src1_type *) src1;                                  \
+                                                                                          \
+        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                        \
+        const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                 \
+                                                                                          \
+        const uint32_t nvec = n / VLEN_FP16;                                              \
+        const uint32_t nloe = n % VLEN_FP16;                                              \
+                                                                                          \
+        uint32_t i = 0;                                                                   \
+                                                                                          \
+        _Pragma("unroll(4)")                                                              \
+        for (; i < nvec; i++) {                                                           \
+            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            vdst[i] = res;                                                                \
+        }                                                                                 \
+        if (nloe) {                                                                       \
+            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                        \
+        }                                                                                 \
+    } while(0)
+
 #define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store)             \
     do {                                                                             \
         dst_type * restrict vdst = (dst_type *) dst;                                 \
         _Pragma("unroll(4)")                                                         \
         for (; i < nvec; i++) {                                                      \
             HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
-            HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1);                         \
+            HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1);                     \
             vdst[i] = res;                                                           \
         }                                                                            \
         if (nloe) {                                                                  \
             HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
-            HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1);                         \
+            HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1);                     \
             vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res);                   \
         }                                                                            \
     } while(0)
 
-// 3-letter suffix variants
-static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((uintptr_t) dst % 128 == 0);
-    assert((uintptr_t) src0 % 128 == 0);
-    assert((uintptr_t) src1 % 128 == 0);
-    hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
-}
-
-static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((uintptr_t) dst % 128 == 0);
-    assert((uintptr_t) src0 % 128 == 0);
-    hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
+static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src0 % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src0 % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    assert((uintptr_t) src1 % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+    OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \
+} \
+
+// Dispatcher logic
+#define HVX_DIV_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128)) { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_aau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_auu(dst, src0, src1, num_elems); \
+        } \
+    } else { \
+        if (hex_is_aligned((void *) src0, 128)) { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uau(dst, src0, src1, num_elems); \
+        } else { \
+            if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
+            else                                    OP_NAME##_uuu(dst, src0, src1, num_elems); \
+        } \
+    } \
 }
 
-static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((uintptr_t) dst % 128 == 0);
-    assert((uintptr_t) src1 % 128 == 0);
-    hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
-}
+DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body)
+DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body)
 
-static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((uintptr_t) dst % 128 == 0);
-    hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
-}
-
-static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((uintptr_t) src0 % 128 == 0);
-    assert((uintptr_t) src1 % 128 == 0);
-    hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((uintptr_t) src0 % 128 == 0);
-    hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    assert((uintptr_t) src1 % 128 == 0);
-    hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
-    hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
-    if (hex_is_aligned((void *) dst, 128)) {
-        if (hex_is_aligned((void *) src0, 128)) {
-            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
-            else                                    hvx_div_f32_aau(dst, src0, src1, num_elems);
-        } else {
-            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
-            else                                    hvx_div_f32_auu(dst, src0, src1, num_elems);
-        }
-    } else {
-        if (hex_is_aligned((void *) src0, 128)) {
-            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
-            else                                    hvx_div_f32_uau(dst, src0, src1, num_elems);
-        } else {
-            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
-            else                                    hvx_div_f32_uuu(dst, src0, src1, num_elems);
-        }
-    }
-}
+HVX_DIV_DISPATCHER(hvx_div_f32)
+HVX_DIV_DISPATCHER(hvx_div_f16)
 
-#undef HVX_OP_MUL
+#undef HVX_OP_MUL_F32
 
 #endif // HVX_DIV_H
index 53db94aae2bf764f3f85b91635d169a9ebef8185..f2054f45baca7b4e60f372049f882e6e3cd6242c 100644 (file)
@@ -137,40 +137,74 @@ static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector n
         }                                                                    \
     } while(0)
 
-static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    assert((unsigned long) src % 128 == 0);
-    hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
-}
+static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) {
+    HVX_Vector out = hvx_vec_inverse_f16(v_sf);
 
-static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    assert((unsigned long) dst % 128 == 0);
-    hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
-}
+    HVX_Vector     masked_out = Q6_V_vand_VV(out, nan_inf_mask);
+    const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out);
 
-static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    assert((unsigned long) src % 128 == 0);
-    hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+    return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
 }
 
-static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
-    hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
-}
+#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store)             \
+    do {                                                                     \
+        dst_type * restrict vdst = (dst_type *) dst;                         \
+        src_type * restrict vsrc = (src_type *) src;                         \
+                                                                             \
+        const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00);              \
+                                                                             \
+        const uint32_t nvec = n / VLEN_FP16;                                 \
+        const uint32_t nloe = n % VLEN_FP16;                                 \
+                                                                             \
+        uint32_t i = 0;                                                      \
+                                                                             \
+        _Pragma("unroll(4)")                                                 \
+        for (; i < nvec; i++) {                                              \
+             vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask);     \
+        }                                                                    \
+        if (nloe) {                                                          \
+            HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v);             \
+        }                                                                    \
+    } while(0)
 
-static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) {
-    if ((unsigned long) dst % 128 == 0) {
-        if ((unsigned long) src % 128 == 0) {
-            hvx_inverse_f32_aa(dst, src, num_elems);
-        } else {
-            hvx_inverse_f32_au(dst, src, num_elems);
-        }
-    } else {
-        if ((unsigned long) src % 128 == 0) {
-            hvx_inverse_f32_ua(dst, src, num_elems);
-        } else {
-            hvx_inverse_f32_uu(dst, src, num_elems);
-        }
-    }
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
+static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    assert((uintptr_t) src % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    assert((uintptr_t) dst % 128 == 0); \
+    OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    assert((uintptr_t) src % 128 == 0); \
+    OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+    OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \
+} \
+
+// Dispatcher logic
+#define HVX_INV_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \
+    if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_aa(dst, src, num_elems); \
+    } else if (hex_is_aligned((void *) dst, 128)) { \
+        OP_NAME##_au(dst, src, num_elems); \
+    } else if (hex_is_aligned((void *) src, 128)) { \
+        OP_NAME##_ua(dst, src, num_elems); \
+    } else { \
+        OP_NAME##_uu(dst, src, num_elems); \
+    } \
 }
 
+DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body)
+DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body)
+
+HVX_INV_DISPATCHER(hvx_inverse_f32)
+HVX_INV_DISPATCHER(hvx_inverse_f16)
+
 #endif // HVX_INVERSE_H
index 9aeb80d0b8b7ae0fdcda4a628b4e226034a4ea3e..be9469538f63a1ca2ef8302b2cc7b4c15d4638c9 100644 (file)
@@ -400,7 +400,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const uint32_t n_threads = octx->n_threads;
+    const uint32_t ne0 = dst->ne[0];
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
 
     const size_t src0_row_size = src0->nb[1];
     const size_t dst_row_size  = dst->nb[1];
@@ -465,17 +467,14 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
     rctx.dst_row_size_aligned  = dst_row_size_aligned;
     rctx.theta_cache_offset    = theta_cache_size_aligned;
 
-    uint32_t ne0 = dst->ne[0];
-    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
     rctx.src0_nrows = src0_nrows;
+    rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
 
     FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
          rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs = MIN(n_threads, src0_nrows);
-        rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads);
     }
 
     return err;
index 2fd6c907724407f726363599c21de58019034412..4b6967749f86f426b1869ac884348254d33c18a8 100644 (file)
@@ -128,6 +128,8 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da
 int op_set_rows(struct htp_ops_context * octx) {
     set_rows_preamble;
 
+    const uint32_t n_threads = MIN(nr, octx->n_threads);
+
     if (octx->src0.type != HTP_TYPE_F32) {
         return HTP_STATUS_NO_SUPPORT;
     }
@@ -149,15 +151,14 @@ int op_set_rows(struct htp_ops_context * octx) {
     srctx.div_ne12 = init_fastdiv_values(ne12);
     srctx.div_ne11 = init_fastdiv_values(ne11);
 
-    const uint32_t n_jobs = MIN(nr, octx->n_threads);
-    srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+    srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
 
     switch(octx->dst.type) {
     case HTP_TYPE_F32:
-        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
         break;
     case HTP_TYPE_F16:
-        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads);
         break;
     default:
         return HTP_STATUS_NO_SUPPORT;
index 6e22eb6a63991727a4f3bdb18955cbb06d098273..8dae7f1ed55391e8babb0e953ce3f513a08cc6af 100644 (file)
@@ -353,7 +353,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const uint32_t n_threads = octx->n_threads;
+    const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
     const size_t src0_row_size = src0->nb[1];
     const size_t src1_row_size = src0_row_size;
@@ -393,12 +394,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
     octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
     octx->dst_spad.data  = octx->src1_spad.data + octx->src1_spad.size;
 
-    uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
-
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs             = MIN(n_threads, src0_nrows);
-        smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
-        worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
+        smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
+        worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
     }
 
     return err;
index 04fa72182a38e8636a6a3a5d2bb8f113f4e13d05..352650b689b1b3db62b284fb1e1a94ec1b1a5adb 100644 (file)
@@ -102,11 +102,9 @@ int op_sum_rows(struct htp_ops_context * octx) {
         return HTP_STATUS_OK;
     }
 
-    const int      n_threads  = octx->n_threads;
     const uint32_t src0_nrows = ne01 * ne02 * ne03;
-
-    uint32_t n_jobs = MIN(n_threads, src0_nrows);
-    uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+    const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
+    const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
 
     bool opt_path = false;
     if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
@@ -124,7 +122,7 @@ int op_sum_rows(struct htp_ops_context * octx) {
         .opt_path        = opt_path,
     };
 
-    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
+    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads);
 
     return HTP_STATUS_OK;
 }
index 98135c50ab8da31aacdc7c875d334a43cfdeee2d..5bbd5040d3dea8b53f5fffa67c22f8d7b70aa7b8 100644 (file)
@@ -301,8 +301,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
             return HTP_STATUS_NO_SUPPORT;
     }
 
-    const int      n_threads  = octx->n_threads;
     const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+    const uint32_t n_threads  = MIN(octx->n_threads, src0_nrows);
 
     const size_t src0_row_size = src0->nb[1];
     const size_t dst_row_size  = dst->nb[1];
@@ -338,11 +338,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
          octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
 
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
-        uint32_t n_jobs = MIN(n_threads, src0_nrows);
-
         struct htp_unary_context uctx = {
             .octx                  = octx,
-            .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
+            .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,
             .src0_nrows            = src0_nrows,
 
             .data_src0             = (const uint8_t *)src0->data,
@@ -361,7 +359,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
             .nc                    = src0->ne[0],
         };
 
-        worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
+        worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
     }
 
     return err;