]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
hexagon: add neg, exp, sigmoid, softplus, cont, repeat ops (llama/20701)
authorKrishna Sridhar <redacted>
Tue, 17 Mar 2026 22:34:36 +0000 (15:34 -0700)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
Add element-wise unary ops needed by Qwen 3.5's DeltaNet linear
attention layers. These ops follow the existing unary-ops pattern
with VTCM DMA double-buffering.

- neg: negate via scale by -1.0
- exp: uses existing hvx_exp_f32 HVX intrinsics
- sigmoid: uses existing hvx_sigmoid_f32_aa HVX intrinsics
- softplus: log(1 + exp(x)) scalar fallback
- CONT reuses the existing CPY infrastructure since making a tensor
  contiguous is equivalent to a same-type copy.
- REPEAT implements tiled memory copy with multi-threaded execution via
  the worker pool, supporting f32 and f16 types. The kernel parallelizes
  across output rows and uses memcpy for each tile.

Co-authored-by: Max Krasnyansky <redacted>
src/ggml-hexagon/ggml-hexagon.cpp
src/ggml-hexagon/htp/CMakeLists.txt
src/ggml-hexagon/htp/htp-msg.h
src/ggml-hexagon/htp/htp-ops.h
src/ggml-hexagon/htp/hvx-base.h
src/ggml-hexagon/htp/hvx-exp.h
src/ggml-hexagon/htp/hvx-sigmoid.h
src/ggml-hexagon/htp/main.c
src/ggml-hexagon/htp/repeat-ops.c [new file with mode: 0644]
src/ggml-hexagon/htp/softmax-ops.c
src/ggml-hexagon/htp/unary-ops.c

index 19917cb114001493688561d77b5b7c292fb20439..4b8a16c363522948420ce59cdad23810c41343cf 100644 (file)
@@ -2362,6 +2362,27 @@ static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs,
     return n_bufs;
 }
 
+static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    // CONT is just a contiguous copy — reuse the CPY op so the HTP side
+    // needs no dedicated CONT kernel (same-type copy, see init_cpy_req above)
+    req->op = HTP_OP_CPY;
+
+    size_t n_bufs = 0;
+    // src0: CPU writes / DSP reads; dst: DSP writes / CPU reads back
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
+static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    // Build a REPEAT request: single input tensor tiled into the output tensor
+    req->op = HTP_OP_REPEAT;
+
+    size_t n_bufs = 0;
+    // src0: CPU writes / DSP reads; dst: DSP writes / CPU reads back
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst,  &bufs[n_bufs], t,         DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
 static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
     req->op = HTP_OP_GET_ROWS;
 
@@ -2449,12 +2470,33 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
             break;
 
         case GGML_OP_UNARY:
-            if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
+            switch (ggml_get_unary_op(t)) {
+            case GGML_UNARY_OP_SILU:
                 req->op   = HTP_OP_UNARY_SILU;
                 supported = true;
-            } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
+                break;
+            case  GGML_UNARY_OP_GELU:
                 req->op   = HTP_OP_UNARY_GELU;
                 supported = true;
+                break;
+            case GGML_UNARY_OP_SIGMOID:
+                req->op   = HTP_OP_UNARY_SIGMOID;
+                supported = true;
+                break;
+            case GGML_UNARY_OP_NEG:
+                req->op   = HTP_OP_UNARY_NEG;
+                supported = true;
+                break;
+            case GGML_UNARY_OP_EXP:
+                req->op   = HTP_OP_UNARY_EXP;
+                supported = true;
+                break;
+            case GGML_UNARY_OP_SOFTPLUS:
+                req->op   = HTP_OP_UNARY_SOFTPLUS;
+                supported = true;
+                break;
+            default:
+                break;
             }
             break;
 
@@ -2640,16 +2682,28 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
                 break;
             case GGML_OP_UNARY:
-                if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
-                        (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
-                    ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+                switch (ggml_get_unary_op(node)) {
+                    case GGML_UNARY_OP_NEG:
+                    case GGML_UNARY_OP_EXP:
+                    case GGML_UNARY_OP_SIGMOID:
+                    case GGML_UNARY_OP_SOFTPLUS:
+                    case GGML_UNARY_OP_SILU:
+                    case GGML_UNARY_OP_GELU:
+                        ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+                        break;
+                    default:
+                        break;
                 }
                 break;
             case GGML_OP_GLU:
-                if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
-                        (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
-                        (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
-                    ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+                switch (ggml_get_glu_op(node)) {
+                    case GGML_GLU_OP_SWIGLU:
+                    case GGML_GLU_OP_SWIGLU_OAI:
+                    case GGML_GLU_OP_GEGLU:
+                        ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+                        break;
+                    default:
+                        break;
                 }
                 break;
             case GGML_OP_SOFT_MAX:
@@ -2676,6 +2730,14 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
                 ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
                 break;
 
+            case GGML_OP_CONT:
+                ggml_hexagon_dispatch_op<init_cont_req>(sess, node, flags);
+                break;
+
+            case GGML_OP_REPEAT:
+                ggml_hexagon_dispatch_op<init_repeat_req>(sess, node, flags);
+                break;
+
             case GGML_OP_ARGSORT:
                 ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
                 break;
@@ -3006,6 +3068,39 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess,
     return true;
 }
 
+static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    GGML_UNUSED(sess);
+    const struct ggml_tensor * src0 = op->src[0];
+
+    // CONT is same-type only, supports f32 and f16
+    // (dispatched as a plain CPY on the HTP side — see init_cont_req)
+    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
+
+    return true;
+}
+
+static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    GGML_UNUSED(sess);
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * dst  = op;
+
+    // Support f32 and f16
+    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
+
+    // src and dst must be the same type
+    if (src0->type != dst->type) return false;
+
+    // dst dims must be multiples of src dims (ggml REPEAT semantics)
+    if (dst->ne[0] % src0->ne[0] != 0) return false;
+    if (dst->ne[1] % src0->ne[1] != 0) return false;
+    if (dst->ne[2] % src0->ne[2] != 0) return false;
+    if (dst->ne[3] % src0->ne[3] != 0) return false;
+
+    // require contiguous tensors (no transposition)
+    // NOTE(review): this only rejects transposed layouts; op_repeat memcpys
+    // whole rows assuming nb[0] == element size — confirm permuted views with
+    // a different nb[0] cannot reach this path.
+    if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false;
+
+    return true;
+}
+
 static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
     auto sess = static_cast<ggml_hexagon_session *>(dev->context);
 
@@ -3063,21 +3158,32 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             break;
 
         case GGML_OP_UNARY:
-            {
-                const auto unary_op = ggml_get_unary_op(op);
-                if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_SIGMOID:
+                case GGML_UNARY_OP_SOFTPLUS:
+                    supp = ggml_hexagon_supported_unary(sess, op);
+                    break;
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_GELU:
                     supp = ggml_hexagon_supported_activations(sess, op);
-                }
-                break;
+                    break;
+                default:
+                    break;
             }
+            break;
         case GGML_OP_GLU:
-            {
-                const auto glu_op = ggml_get_glu_op(op);
-                if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
+                case GGML_GLU_OP_GEGLU:
                     supp = ggml_hexagon_supported_activations(sess, op);
-                }
-                break;
+                    break;
+                default:
+                    break;
             }
+            break;
         case GGML_OP_ROPE:
             supp = ggml_hexagon_supported_rope(sess, op);
             break;
@@ -3098,6 +3204,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
             supp = ggml_hexagon_supported_cpy(sess, op);
             break;
 
+        case GGML_OP_CONT:
+            supp = ggml_hexagon_supported_cont(sess, op);
+            break;
+
+        case GGML_OP_REPEAT:
+            supp = ggml_hexagon_supported_repeat(sess, op);
+            break;
+
         case GGML_OP_ARGSORT:
             supp = ggml_hexagon_supported_argsort(sess, op);
             break;
index 02d07a503d50fb405f90b25c5cf91ad84ac090a4..a490a2ce9a167a975da4ba507604e8daf3228065 100644 (file)
@@ -30,6 +30,7 @@ add_library(${HTP_LIB} SHARED
     set-rows-ops.c
     get-rows-ops.c
     cpy-ops.c
+    repeat-ops.c
     argsort-ops.c
     ssm-conv.c
 )
index 52dcc36d8f7fb0402e3100e877cbe7c8cffbe563..56bc5b622c5892e6ff512370001777faaeb8feb2 100644 (file)
@@ -53,6 +53,10 @@ enum htp_op {
     HTP_OP_RMS_NORM,
     HTP_OP_UNARY_SILU,
     HTP_OP_UNARY_GELU,
+    HTP_OP_UNARY_SIGMOID,
+    HTP_OP_UNARY_EXP,
+    HTP_OP_UNARY_NEG,
+    HTP_OP_UNARY_SOFTPLUS,
     HTP_OP_GLU_SWIGLU,
     HTP_OP_GLU_SWIGLU_OAI,
     HTP_OP_GLU_GEGLU,
@@ -69,6 +73,7 @@ enum htp_op {
     HTP_OP_SQRT,
     HTP_OP_SUM_ROWS,
     HTP_OP_SSM_CONV,
+    HTP_OP_REPEAT,
     INVALID
 };
 
index 2ef20936f1bb28847c3f53e121e9319f519ea985..f643fdc340dbb845e9505b621a4000622c83ff49 100644 (file)
@@ -57,6 +57,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
 int op_set_rows(struct htp_ops_context * octx);
 int op_get_rows(struct htp_ops_context * octx);
 int op_cpy(struct htp_ops_context * octx);
+int op_repeat(struct htp_ops_context * octx);
 int op_argsort(struct htp_ops_context * octx);
 int op_ssm_conv(struct htp_ops_context * octx);
 
index 578ca288fb65f288c53de83fab24f70957de7f35..3e6a8579b1f0b26a191fd4d712ca86779e3ee4c1 100644 (file)
@@ -3,6 +3,8 @@
 
 #include <stdbool.h>
 #include <stdint.h>
+#include <math.h>
+#include <assert.h>
 
 #include "hex-utils.h"
 #include "hvx-types.h"
index 44dfe232a3d4c0b77f679589682caa1dd6233486..84e4836dc92b2c472892dd3448ba9ff4695e567b 100644 (file)
@@ -3,6 +3,7 @@
 
 #include <stdbool.h>
 #include <stdint.h>
+#include <math.h>
 
 #include "hvx-base.h"
 #include "hvx-floor.h"
@@ -16,8 +17,8 @@
 #define EXP_LOGN2   (0x3F317218)  // ln(2)   = 0.6931471805
 #define EXP_LOG2E   (0x3FB8AA3B)  // log2(e) = 1/ln(2) = 1.4426950408
 #define EXP_ONE     (0x3f800000)  // 1.0
-#define EXP_RANGE_R (0x41a00000)  // 20.0
-#define EXP_RANGE_L (0xc1a00000)  // -20.0
+#define EXP_RANGE_R (0x42B16666)  // 88.7
+#define EXP_RANGE_L (0xC2B00000)  // -88.0 (approx log(FLT_MIN))
 
 static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
     HVX_Vector z_qf32_v;
@@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
 
     HVX_Vector temp_v = in_vec;
 
-    // Clamp inputs to (-20.0, 20.0)
+    // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow
     HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
     HVX_VectorPred pred_cap_left  = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
 
     in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
-    in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
+    in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
 
     epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
     epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
@@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
     // normalize before every QFloat's vmpy
     x_qf32_v  = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
 
+    x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
+
     // z = x * x;
     z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
     z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
 
-    x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
-
     // y = E4 + E5 * x;
     E_const = Q6_V_vsplat_R(EXP_COEFF_5);
     y_v     = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
@@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max
     return Q6_V_vmux_QVV(pred0, inf, out);
 }
 
-static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
+static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) {
     int left_over       = num_elems & (VLEN_FP32 - 1);
     int num_elems_whole = num_elems - left_over;
 
@@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict
     HVX_Vector vec_out = Q6_V_vzero();
 
     static const float kInf    = INFINITY;
-    static const float kMaxExp = 88.02f;  // log(INF)
+    static const float kMaxExp = 88.7f;
 
     const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
     const HVX_Vector inf     = hvx_vec_splat_f32(kInf);
index 095193277ea9448d1291033c703ed626ad124f26..37f3e7b6faec5030b12c84ec2aadb5f4f1295595 100644 (file)
@@ -2,6 +2,7 @@
 #define HVX_SIGMOID_H
 
 #include "hvx-base.h"
+#include "hvx-inverse.h"
 
 #define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022
 #define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777
index 3f99dbb32c478e7235be71a3c9bbb90cbaba7195..2a3f9e562b70aa7a5226a28c0f54e53aeb5c4175 100644 (file)
@@ -516,6 +516,39 @@ static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req,
     send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
 }
 
+static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    // Unpack a REPEAT request from the DSP queue, run op_repeat, and send the
+    // response back with cache-maintenance flags for the output buffer.
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // The DSP writes the output buffer (bufs[1]), so it must be flushed on the
+    // DSP side and invalidated on the CPU side before the CPU reads the result.
+    // NOTE(review): only fd/ptr/offset/size/flags are set here; confirm the
+    // remaining dspqueue_buffer fields are ignored by the response path.
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                         DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT);  // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx                    = ctx;
+    octx.src0                   = req->src0;
+    octx.dst                    = req->dst;
+    octx.flags                  = req->flags;
+    octx.op                     = req->op;
+
+    // Update data pointers (queue buffers carry the actual mapped addresses)
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = op_repeat(&octx);
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
 static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
     struct dspqueue_buffer rsp_bufs[1];
 
@@ -1090,6 +1123,10 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
 
             case HTP_OP_SQR:
             case HTP_OP_SQRT:
+            case HTP_OP_UNARY_NEG:
+            case HTP_OP_UNARY_EXP:
+            case HTP_OP_UNARY_SIGMOID:
+            case HTP_OP_UNARY_SOFTPLUS:
                 if (n_bufs != 2) {
                     FARF(ERROR, "Bad unary-req buffer list");
                     continue;
@@ -1175,6 +1212,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
                 proc_cpy_req(ctx, &req, bufs);
                 break;
 
+            case HTP_OP_REPEAT:
+                if (n_bufs != 2) {
+                    FARF(ERROR, "Bad repeat-req buffer list");
+                    continue;
+                }
+                proc_repeat_req(ctx, &req, bufs);
+                break;
+
             case HTP_OP_ARGSORT:
                 if (n_bufs != 2) {
                     FARF(ERROR, "Bad argsort-req buffer list");
diff --git a/src/ggml-hexagon/htp/repeat-ops.c b/src/ggml-hexagon/htp/repeat-ops.c
new file mode 100644 (file)
index 0000000..5db06c9
--- /dev/null
@@ -0,0 +1,148 @@
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <string.h>
+
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+struct htp_repeat_context {
+    struct htp_ops_context * octx;  // op context holding src0/dst tensors
+
+    // repeat factors per dimension: dst->ne[i] / src->ne[i]
+    uint32_t nr0;
+    uint32_t nr1;
+    uint32_t nr2;
+    uint32_t nr3;
+
+    uint32_t nrows_per_thread;  // ceil(total_dst_rows / n_threads)
+    uint32_t total_dst_rows;  // ne1 * ne2 * ne3
+
+    size_t   type_size;  // bytes per element (4 = f32, 2 = f16)
+};
+
+static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    // Worker-pool job: copy this thread's slice of dst rows.
+    // Each dst row (i1, i2, i3) reads src row (i1 % ne01, i2 % ne02, i3 % ne03);
+    // along dim 0 the src row is memcpy'd nr0 times back-to-back.
+    const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data;
+    struct htp_ops_context * octx = rctx->octx;
+    const struct htp_tensor * src = &octx->src0;
+    const struct htp_tensor * dst = &octx->dst;
+
+    const uint32_t ne00 = src->ne[0];
+    const uint32_t ne01 = src->ne[1];
+    const uint32_t ne02 = src->ne[2];
+    const uint32_t ne03 = src->ne[3];
+
+    const uint32_t nb01 = src->nb[1];
+    const uint32_t nb02 = src->nb[2];
+    const uint32_t nb03 = src->nb[3];
+
+    const uint32_t ne1 = dst->ne[1];
+    const uint32_t ne2 = dst->ne[2];
+
+    const uint32_t nb0 = dst->nb[0];
+    const uint32_t nb1 = dst->nb[1];
+    const uint32_t nb2 = dst->nb[2];
+    const uint32_t nb3 = dst->nb[3];
+
+    const uint32_t nr0 = rctx->nr0;
+
+    const size_t row_bytes = ne00 * rctx->type_size;
+
+    const uint32_t row_start = rctx->nrows_per_thread * ith;
+    const uint32_t row_end   = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows);
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) {
+        // Decompose flat dst row index into (i1, i2, i3)
+        const uint32_t i1 = dst_row % ne1;
+        const uint32_t i2 = (dst_row / ne1) % ne2;
+        const uint32_t i3 = dst_row / (ne1 * ne2);
+
+        // Map to source indices (tiling)
+        const uint32_t k1 = i1 % ne01;
+        const uint32_t k2 = i2 % ne02;
+        const uint32_t k3 = i3 % ne03;
+
+        const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03;
+        uint8_t * dst_base      = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
+
+        // Tile along dimension 0: nr0 back-to-back copies of the src row
+        for (uint32_t i0 = 0; i0 < nr0; i0++) {
+            uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0;
+            memcpy(dst_ptr, src_row, row_bytes);
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    // %u for the unsigned thread indices (ith/nth were printed with %d)
+    FARF(HIGH, "repeat %u/%u: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
+         ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+int op_repeat(struct htp_ops_context * octx) {
+    // Entry point for HTP_OP_REPEAT: tile src0 into dst (f32/f16, same type),
+    // parallelized across dst rows via the worker pool.
+    // Returns HTP_STATUS_OK on success, or an HTP_STATUS_* error code.
+    const struct htp_tensor * src0 = &octx->src0;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    // Validate that dst dims are multiples of src dims
+    // (assumes src dims are non-zero — TODO confirm the host guarantees this)
+    if (dst->ne[0] % src0->ne[0] != 0 ||
+        dst->ne[1] % src0->ne[1] != 0 ||
+        dst->ne[2] % src0->ne[2] != 0 ||
+        dst->ne[3] % src0->ne[3] != 0) {
+        FARF(ERROR, "repeat: dst dims must be multiples of src dims\n");
+        return HTP_STATUS_INVAL_PARAMS;
+    }
+
+    size_t type_size;
+    switch (src0->type) {
+        case HTP_TYPE_F32: type_size = 4; break;
+        case HTP_TYPE_F16: type_size = 2; break;
+        default:
+            FARF(ERROR, "repeat: unsupported type %u\n", src0->type);
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    // Host asked to skip the actual compute — bail out before any split math
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3];
+    if (total_dst_rows == 0) {
+        return HTP_STATUS_OK;  // empty output: avoid div-by-zero in the row split below
+    }
+    const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows);
+
+    struct htp_repeat_context rctx = {
+        .octx             = octx,
+        .nr0              = dst->ne[0] / src0->ne[0],
+        .nr1              = dst->ne[1] / src0->ne[1],
+        .nr2              = dst->ne[2] / src0->ne[2],
+        .nr3              = dst->ne[3] / src0->ne[3],
+        .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads,
+        .total_dst_rows   = total_dst_rows,
+        .type_size        = type_size,
+    };
+
+    FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n",
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3);
+
+    worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads);
+
+    return HTP_STATUS_OK;
+}
index 8dae7f1ed55391e8babb0e953ce3f513a08cc6af..d6356b9506fb963b9de3c3a8babeca63e9ddf354 100644 (file)
@@ -195,7 +195,7 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
                              const float max) {
     hvx_sub_scalar_f32(spad, src, max, num_elems);
 
-    hvx_exp_f32(spad, dst, num_elems, false);
+    hvx_exp_f32(dst, spad, num_elems, false);
 
     float sum = hvx_reduce_sum_f32(dst, num_elems);
 
index 5bbd5040d3dea8b53f5fffa67c22f8d7b70aa7b8..3d0928d4dce164d802593c6ca6a411abffac43c1 100644 (file)
@@ -9,6 +9,8 @@
 #include <string.h>
 
 #include "hex-dma.h"
+#include "hvx-exp.h"
+#include "hvx-sigmoid.h"
 #include "hvx-utils.h"
 
 #define GGML_COMMON_DECL_C
@@ -166,6 +168,75 @@ static void sqrt_f32(const float * restrict src,
     }
 }
 
+static void neg_f32(const float * restrict src,
+                    float * restrict dst,
+                    uint8_t * restrict spad,
+                    const uint32_t num_rows,
+                    const uint32_t row_elems,
+                    const size_t   row_size,
+                    int32_t *      op_params) {
+    // Element-wise negation: dst = -src, implemented as an HVX scale by -1.0f.
+    // spad and op_params are unused; signature matches the other unary ops.
+    const uint8_t * restrict p_src = (const uint8_t *) src;
+    uint8_t * restrict       p_dst = (uint8_t *) dst;
+
+    for (uint32_t ir = 0; ir < num_rows; ir++, p_src += row_size, p_dst += row_size) {
+        hvx_scale_f32_aa(p_dst, p_src, row_elems, -1.0f);
+    }
+}
+
+static void exp_f32(const float * restrict src,
+                    float * restrict dst,
+                    uint8_t * restrict spad,
+                    const uint32_t num_rows,
+                    const uint32_t row_elems,
+                    const size_t   row_size,
+                    int32_t *      op_params) {
+    // Element-wise exp over each row using the HVX exp kernel (negate = false).
+    // spad and op_params are unused; signature matches the other unary ops.
+    const uint8_t * restrict p_src = (const uint8_t *) src;
+    uint8_t * restrict       p_dst = (uint8_t *) dst;
+
+    for (uint32_t ir = 0; ir < num_rows; ir++, p_src += row_size, p_dst += row_size) {
+        hvx_exp_f32(p_dst, p_src, row_elems, false);
+    }
+}
+
+static void sigmoid_f32(const float * restrict src,
+                        float * restrict dst,
+                        uint8_t * restrict spad,
+                        const uint32_t num_rows,
+                        const uint32_t row_elems,
+                        const size_t   row_size,
+                        int32_t *      op_params) {
+    // Element-wise sigmoid over each row using the aligned HVX sigmoid kernel.
+    // spad and op_params are unused; signature matches the other unary ops.
+    const uint8_t * restrict p_src = (const uint8_t *) src;
+    uint8_t * restrict       p_dst = (uint8_t *) dst;
+
+    for (uint32_t ir = 0; ir < num_rows; ir++, p_src += row_size, p_dst += row_size) {
+        hvx_sigmoid_f32_aa(p_dst, p_src, row_elems);
+    }
+}
+
+static void softplus_f32(const float * restrict src,
+                         float * restrict dst,
+                         uint8_t * restrict spad,
+                         const uint32_t num_rows,
+                         const uint32_t row_elems,
+                         const size_t   row_size,
+                         int32_t *      op_params) {
+    // softplus(x) = log(1 + exp(x)) — scalar fallback (no HVX kernel yet).
+    // Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h
+    // NOTE(review): log1pf(expf(x)) would be more accurate for very negative x,
+    // but this form is kept to mirror the CPU reference — confirm before changing.
+    // spad and op_params are unused; signature matches the other unary ops.
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size));
+        float * restrict dst_f       = (float *)((uint8_t *)dst + (ir * row_size));
+
+        for (uint32_t i = 0; i < row_elems; i++) {
+            float x = src_f[i];
+            // For x > 20: softplus(x) ≈ x (avoids exp overflow)
+            dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x));
+        }
+    }
+}
+
 static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
     const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
     struct htp_ops_context * octx = uctx->octx;
@@ -247,6 +318,18 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void *
             case HTP_OP_SQRT:
                 sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
                 break;
+            case HTP_OP_UNARY_NEG:
+                neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_UNARY_EXP:
+                exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_UNARY_SIGMOID:
+                sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
+            case HTP_OP_UNARY_SOFTPLUS:
+                softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+                break;
             default:
                 break;
         }
@@ -295,6 +378,18 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
         case HTP_OP_SQRT:
             op_type = "sqrt-f32";
             break;
+        case HTP_OP_UNARY_NEG:
+            op_type = "neg-f32";
+            break;
+        case HTP_OP_UNARY_EXP:
+            op_type = "exp-f32";
+            break;
+        case HTP_OP_UNARY_SIGMOID:
+            op_type = "sigmoid-f32";
+            break;
+        case HTP_OP_UNARY_SOFTPLUS:
+            op_type = "softplus-f32";
+            break;
 
         default:
             FARF(ERROR, "Unsupported unary Op %u\n", octx->op);