return n_bufs;
}
+// Build an HTP request for GGML_OP_CONT.
+// CONT of a (possibly non-contiguous view) into a contiguous dst is just a
+// copy, so the DSP-side CPY op is reused; returns the number of dspqueue
+// buffers initialized (src0 + dst).
+static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    // CONT is just a contiguous copy — reuse CPY op
+    req->op = HTP_OP_CPY;
+
+    size_t n_bufs = 0;
+    // src0: produced on the CPU, consumed on the DSP
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    // dst: produced on the DSP, consumed on the CPU
+    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
+// Build an HTP request for GGML_OP_REPEAT (tile src0 into dst).
+// Returns the number of dspqueue buffers initialized (src0 + dst).
+static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_REPEAT;
+
+    size_t n_bufs = 0;
+    // src0: produced on the CPU, consumed on the DSP
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    // dst: produced on the DSP, consumed on the CPU
+    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_GET_ROWS;
break;
case GGML_OP_UNARY:
- if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
+ switch (ggml_get_unary_op(t)) {
+ case GGML_UNARY_OP_SILU:
req->op = HTP_OP_UNARY_SILU;
supported = true;
- } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) {
+ break;
+ case GGML_UNARY_OP_GELU:
req->op = HTP_OP_UNARY_GELU;
supported = true;
+ break;
+ case GGML_UNARY_OP_SIGMOID:
+ req->op = HTP_OP_UNARY_SIGMOID;
+ supported = true;
+ break;
+ case GGML_UNARY_OP_NEG:
+ req->op = HTP_OP_UNARY_NEG;
+ supported = true;
+ break;
+ case GGML_UNARY_OP_EXP:
+ req->op = HTP_OP_UNARY_EXP;
+ supported = true;
+ break;
+ case GGML_UNARY_OP_SOFTPLUS:
+ req->op = HTP_OP_UNARY_SOFTPLUS;
+ supported = true;
+ break;
+ default:
+ break;
}
break;
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
break;
case GGML_OP_UNARY:
- if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
- (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
- ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ switch (ggml_get_unary_op(node)) {
+ case GGML_UNARY_OP_NEG:
+ case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_SOFTPLUS:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_GELU:
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ break;
+ default:
+ break;
}
break;
case GGML_OP_GLU:
- if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
- (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
- (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
- ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ switch (ggml_get_glu_op(node)) {
+ case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
+ case GGML_GLU_OP_GEGLU:
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ break;
+ default:
+ break;
}
break;
case GGML_OP_SOFT_MAX:
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
break;
+ case GGML_OP_CONT:
+ ggml_hexagon_dispatch_op<init_cont_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_REPEAT:
+ ggml_hexagon_dispatch_op<init_repeat_req>(sess, node, flags);
+ break;
+
case GGML_OP_ARGSORT:
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break;
return true;
}
+// Check whether GGML_OP_CONT can be offloaded to the HTP for this tensor.
+// CONT is dispatched via the DSP CPY op as a plain (non-converting) copy,
+// so both operands must share a supported type.
+static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    GGML_UNUSED(sess);
+    const struct ggml_tensor * src0 = op->src[0];
+
+    // CONT is same-type only, supports f32 and f16
+    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
+
+    // Enforce the same-type invariant explicitly: the CPY path is reused here
+    // only for a plain contiguous copy, never for a type conversion.
+    if (src0->type != op->type) return false;
+
+    return true;
+}
+
+// Check whether GGML_OP_REPEAT can be offloaded to the HTP for this tensor.
+// Supported: f32/f16, same src/dst type, dst dims exact multiples of src
+// dims, no transposed operands.
+static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    GGML_UNUSED(sess);
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * dst  = op;
+
+    // Support f32 and f16
+    if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false;
+
+    // src and dst must be the same type
+    if (src0->type != dst->type) return false;
+
+    // dst dims must be multiples of src dims; guard against a zero src dim,
+    // which would otherwise divide by zero in the modulo below
+    for (int i = 0; i < 4; i++) {
+        if (src0->ne[i] == 0) return false;
+        if (dst->ne[i] % src0->ne[i] != 0) return false;
+    }
+
+    // require contiguous tensors (no transposition)
+    if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false;
+
+    return true;
+}
+
static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
auto sess = static_cast<ggml_hexagon_session *>(dev->context);
break;
case GGML_OP_UNARY:
- {
- const auto unary_op = ggml_get_unary_op(op);
- if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) {
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_NEG:
+ case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_SOFTPLUS:
+ supp = ggml_hexagon_supported_unary(sess, op);
+ break;
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_GELU:
supp = ggml_hexagon_supported_activations(sess, op);
- }
- break;
+ break;
+ default:
+ break;
}
+ break;
case GGML_OP_GLU:
- {
- const auto glu_op = ggml_get_glu_op(op);
- if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
+ switch (ggml_get_glu_op(op)) {
+ case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
+ case GGML_GLU_OP_GEGLU:
supp = ggml_hexagon_supported_activations(sess, op);
- }
- break;
+ break;
+ default:
+ break;
}
+ break;
case GGML_OP_ROPE:
supp = ggml_hexagon_supported_rope(sess, op);
break;
supp = ggml_hexagon_supported_cpy(sess, op);
break;
+ case GGML_OP_CONT:
+ supp = ggml_hexagon_supported_cont(sess, op);
+ break;
+
+ case GGML_OP_REPEAT:
+ supp = ggml_hexagon_supported_repeat(sess, op);
+ break;
+
case GGML_OP_ARGSORT:
supp = ggml_hexagon_supported_argsort(sess, op);
break;
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
+ repeat-ops.c
argsort-ops.c
ssm-conv.c
)
HTP_OP_RMS_NORM,
HTP_OP_UNARY_SILU,
HTP_OP_UNARY_GELU,
+ HTP_OP_UNARY_SIGMOID,
+ HTP_OP_UNARY_EXP,
+ HTP_OP_UNARY_NEG,
+ HTP_OP_UNARY_SOFTPLUS,
HTP_OP_GLU_SWIGLU,
HTP_OP_GLU_SWIGLU_OAI,
HTP_OP_GLU_GEGLU,
HTP_OP_SQRT,
HTP_OP_SUM_ROWS,
HTP_OP_SSM_CONV,
+ HTP_OP_REPEAT,
INVALID
};
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
+int op_repeat(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
int op_ssm_conv(struct htp_ops_context * octx);
#include <stdbool.h>
#include <stdint.h>
+#include <math.h>
+#include <assert.h>
#include "hex-utils.h"
#include "hvx-types.h"
#include <stdbool.h>
#include <stdint.h>
+#include <math.h>
#include "hvx-base.h"
#include "hvx-floor.h"
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
#define EXP_ONE (0x3f800000) // 1.0
-#define EXP_RANGE_R (0x41a00000) // 20.0
-#define EXP_RANGE_L (0xc1a00000) // -20.0
+#define EXP_RANGE_R (0x42B16666) // 88.7
+#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) {
HVX_Vector z_qf32_v;
HVX_Vector temp_v = in_vec;
- // Clamp inputs to (-20.0, 20.0)
+ // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
- in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
+ in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
// normalize before every QFloat's vmpy
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
+ x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
+
// z = x * x;
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
- x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
-
// y = E4 + E5 * x;
E_const = Q6_V_vsplat_R(EXP_COEFF_5);
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
return Q6_V_vmux_QVV(pred0, inf, out);
}
-static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
+static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
HVX_Vector vec_out = Q6_V_vzero();
static const float kInf = INFINITY;
- static const float kMaxExp = 88.02f; // log(INF)
+ static const float kMaxExp = 88.7f;
const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp);
const HVX_Vector inf = hvx_vec_splat_f32(kInf);
#define HVX_SIGMOID_H
#include "hvx-base.h"
+#include "hvx-inverse.h"
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
+// DSP-side handler for HTP_OP_REPEAT: runs op_repeat() on the two mapped
+// buffers (bufs[0] = src0, bufs[1] = dst) and sends the response with
+// profiling data. Mirrors the other proc_*_req handlers.
+static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |        // Flush HTP
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx   = ctx;
+    octx.src0  = req->src0;
+    octx.dst   = req->dst;
+    octx.flags = req->flags;
+    octx.op    = req->op;
+
+    // Update data pointers to the DSP-side mappings of the shared buffers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = op_repeat(&octx);
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
case HTP_OP_SQR:
case HTP_OP_SQRT:
+ case HTP_OP_UNARY_NEG:
+ case HTP_OP_UNARY_EXP:
+ case HTP_OP_UNARY_SIGMOID:
+ case HTP_OP_UNARY_SOFTPLUS:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
proc_cpy_req(ctx, &req, bufs);
break;
+ case HTP_OP_REPEAT:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad repeat-req buffer list");
+ continue;
+ }
+ proc_repeat_req(ctx, &req, bufs);
+ break;
+
case HTP_OP_ARGSORT:
if (n_bufs != 2) {
FARF(ERROR, "Bad argsort-req buffer list");
--- /dev/null
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <string.h>
+
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+// Per-dispatch state shared by all repeat worker threads.
+struct htp_repeat_context {
+    struct htp_ops_context * octx;   // op context holding src0/dst tensors
+
+    // Repeat factors per dimension: dst->ne[i] / src->ne[i]
+    uint32_t nr0;
+    uint32_t nr1;
+    uint32_t nr2;
+    uint32_t nr3;
+
+    uint32_t nrows_per_thread;       // ceil(total_dst_rows / n_threads)
+    uint32_t total_dst_rows; // ne1 * ne2 * ne3
+
+    size_t type_size;                // element size in bytes (4 for f32, 2 for f16)
+};
+
+// Worker-pool job: copies this thread's slice of dst rows.
+// Each dst row (i1, i2, i3) maps to src row (i1 % ne01, i2 % ne02, i3 % ne03),
+// so tiling along dims 1-3 is handled by the modulo mapping alone (nr1..nr3
+// are carried in rctx but not read here). Tiling along dim 0 is done by
+// copying the src row nr0 times side by side.
+// NOTE(review): dst_ptr advances by i0 * ne00 * nb0, i.e. assumes the dst row
+// is a dense run of nr0 src-row-sized copies — holds for contiguous dst;
+// confirm the supports_op gate guarantees this.
+static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) {
+    const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data;
+    struct htp_ops_context * octx = rctx->octx;
+    const struct htp_tensor * src = &octx->src0;
+    const struct htp_tensor * dst = &octx->dst;
+
+    // src dims/strides
+    const uint32_t ne00 = src->ne[0];
+    const uint32_t ne01 = src->ne[1];
+    const uint32_t ne02 = src->ne[2];
+    const uint32_t ne03 = src->ne[3];
+
+    const uint32_t nb00 = src->nb[0];
+    const uint32_t nb01 = src->nb[1];
+    const uint32_t nb02 = src->nb[2];
+    const uint32_t nb03 = src->nb[3];
+
+    // dst dims/strides
+    const uint32_t ne0 = dst->ne[0];
+    const uint32_t ne1 = dst->ne[1];
+    const uint32_t ne2 = dst->ne[2];
+    const uint32_t ne3 = dst->ne[3];
+
+    const uint32_t nb0 = dst->nb[0];
+    const uint32_t nb1 = dst->nb[1];
+    const uint32_t nb2 = dst->nb[2];
+    const uint32_t nb3 = dst->nb[3];
+
+    const uint32_t nr0 = rctx->nr0;
+    const uint32_t nr1 = rctx->nr1;
+    const uint32_t nr2 = rctx->nr2;
+    const uint32_t nr3 = rctx->nr3;
+
+    const size_t row_bytes = ne00 * rctx->type_size;
+
+    // Static partition of dst rows across threads; MIN clamps the last slice
+    const uint32_t row_start = rctx->nrows_per_thread * ith;
+    const uint32_t row_end   = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows);
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) {
+        // Decompose flat dst row index into (i1, i2, i3)
+        const uint32_t i1 = dst_row % ne1;
+        const uint32_t i2 = (dst_row / ne1) % ne2;
+        const uint32_t i3 = dst_row / (ne1 * ne2);
+
+        // Map to source indices (tiling)
+        const uint32_t k1 = i1 % ne01;
+        const uint32_t k2 = i2 % ne02;
+        const uint32_t k3 = i3 % ne03;
+
+        const uint8_t * src_row  = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03;
+        uint8_t *       dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3;
+
+        // Tile along dimension 0
+        for (uint32_t i0 = 0; i0 < nr0; i0++) {
+            uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0;
+            memcpy(dst_ptr, src_row, row_bytes);
+        }
+    }
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n",
+         ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3],
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// HTP implementation of GGML_OP_REPEAT: tile src0 into dst across all four
+// dimensions. Validates shapes/types, then fans the dst rows out over the
+// worker pool. Returns an HTP_STATUS_* code.
+int op_repeat(struct htp_ops_context * octx) {
+    const struct htp_tensor * src0 = &octx->src0;
+    struct htp_tensor *       dst  = &octx->dst;
+
+    // Guard degenerate src dims up front: a zero dim would divide by zero
+    // in the multiple checks and repeat-factor computation below
+    if (src0->ne[0] == 0 || src0->ne[1] == 0 || src0->ne[2] == 0 || src0->ne[3] == 0) {
+        FARF(ERROR, "repeat: empty src tensor\n");
+        return HTP_STATUS_INVAL_PARAMS;
+    }
+
+    // Validate that dst dims are multiples of src dims
+    if (dst->ne[0] % src0->ne[0] != 0 ||
+        dst->ne[1] % src0->ne[1] != 0 ||
+        dst->ne[2] % src0->ne[2] != 0 ||
+        dst->ne[3] % src0->ne[3] != 0) {
+        FARF(ERROR, "repeat: dst dims must be multiples of src dims\n");
+        return HTP_STATUS_INVAL_PARAMS;
+    }
+
+    size_t type_size;
+    switch (src0->type) {
+        case HTP_TYPE_F32: type_size = 4; break;
+        case HTP_TYPE_F16: type_size = 2; break;
+        default:
+            FARF(ERROR, "repeat: unsupported type %u\n", src0->type);
+            return HTP_STATUS_NO_SUPPORT;
+    }
+
+    const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3];
+
+    // Empty dst is a no-op; also prevents n_threads == 0 and the resulting
+    // division by zero in nrows_per_thread below
+    if (total_dst_rows == 0) {
+        return HTP_STATUS_OK;
+    }
+
+    const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows);
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    struct htp_repeat_context rctx = {
+        .octx             = octx,
+        .nr0              = dst->ne[0] / src0->ne[0],
+        .nr1              = dst->ne[1] / src0->ne[1],
+        .nr2              = dst->ne[2] / src0->ne[2],
+        .nr3              = dst->ne[3] / src0->ne[3],
+        .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads,
+        .total_dst_rows   = total_dst_rows,
+        .type_size        = type_size,
+    };
+
+    FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n",
+         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+         rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3);
+
+    worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads);
+
+    return HTP_STATUS_OK;
+}
const float max) {
hvx_sub_scalar_f32(spad, src, max, num_elems);
- hvx_exp_f32(spad, dst, num_elems, false);
+ hvx_exp_f32(dst, spad, num_elems, false);
float sum = hvx_reduce_sum_f32(dst, num_elems);
#include <string.h>
#include "hex-dma.h"
+#include "hvx-exp.h"
+#include "hvx-sigmoid.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
}
}
+// Row-wise negation: dst = -src, one HVX scale-by(-1) per row.
+// spad and op_params are unused; the signature matches the other unary ops.
+static void neg_f32(const float * restrict src,
+                    float * restrict dst,
+                    uint8_t * restrict spad,
+                    const uint32_t num_rows,
+                    const uint32_t row_elems,
+                    const size_t row_size,
+                    int32_t * op_params) {
+    const uint8_t * restrict sp = (const uint8_t *) src;
+    uint8_t * restrict       dp = (uint8_t *) dst;
+
+    for (uint32_t ir = 0; ir < num_rows; ir++, sp += row_size, dp += row_size) {
+        hvx_scale_f32_aa(dp, sp, row_elems, -1.0f);
+    }
+}
+
+// Row-wise exponential: dst = exp(src), one HVX pass per row.
+// spad and op_params are unused; the signature matches the other unary ops.
+static void exp_f32(const float * restrict src,
+                    float * restrict dst,
+                    uint8_t * restrict spad,
+                    const uint32_t num_rows,
+                    const uint32_t row_elems,
+                    const size_t row_size,
+                    int32_t * op_params) {
+    const uint8_t * restrict sp = (const uint8_t *) src;
+    uint8_t * restrict       dp = (uint8_t *) dst;
+
+    for (uint32_t ir = 0; ir < num_rows; ir++, sp += row_size, dp += row_size) {
+        hvx_exp_f32(dp, sp, row_elems, false);
+    }
+}
+
+// Row-wise sigmoid: dst = 1 / (1 + exp(-src)), one HVX pass per row.
+// spad and op_params are unused; the signature matches the other unary ops.
+static void sigmoid_f32(const float * restrict src,
+                        float * restrict dst,
+                        uint8_t * restrict spad,
+                        const uint32_t num_rows,
+                        const uint32_t row_elems,
+                        const size_t row_size,
+                        int32_t * op_params) {
+    const uint8_t * restrict sp = (const uint8_t *) src;
+    uint8_t * restrict       dp = (uint8_t *) dst;
+
+    for (uint32_t ir = 0; ir < num_rows; ir++, sp += row_size, dp += row_size) {
+        hvx_sigmoid_f32_aa(dp, sp, row_elems);
+    }
+}
+
+// Row-wise softplus: dst = log(1 + exp(src)), scalar loop.
+// Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h.
+// spad and op_params are unused; the signature matches the other unary ops.
+static void softplus_f32(const float * restrict src,
+                         float * restrict dst,
+                         uint8_t * restrict spad,
+                         const uint32_t num_rows,
+                         const uint32_t row_elems,
+                         const size_t row_size,
+                         int32_t * op_params) {
+    const uint8_t * restrict src_bytes = (const uint8_t *) src;
+    uint8_t * restrict       dst_bytes = (uint8_t *) dst;
+
+    for (uint32_t ir = 0; ir < num_rows; ir++) {
+        const float * restrict x_row = (const float *) (src_bytes + ir * row_size);
+        float * restrict       y_row = (float *) (dst_bytes + ir * row_size);
+
+        for (uint32_t i = 0; i < row_elems; i++) {
+            const float x = x_row[i];
+            // For x > 20: softplus(x) ≈ x (avoids exp overflow)
+            y_row[i] = (x > 20.0f) ? x : logf(1.0f + expf(x));
+        }
+    }
+}
+
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
struct htp_ops_context * octx = uctx->octx;
case HTP_OP_SQRT:
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
break;
+ case HTP_OP_UNARY_NEG:
+ neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
+ case HTP_OP_UNARY_EXP:
+ exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
+ case HTP_OP_UNARY_SIGMOID:
+ sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
+ case HTP_OP_UNARY_SOFTPLUS:
+ softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
default:
break;
}
case HTP_OP_SQRT:
op_type = "sqrt-f32";
break;
+ case HTP_OP_UNARY_NEG:
+ op_type = "neg-f32";
+ break;
+ case HTP_OP_UNARY_EXP:
+ op_type = "exp-f32";
+ break;
+ case HTP_OP_UNARY_SIGMOID:
+ op_type = "sigmoid-f32";
+ break;
+ case HTP_OP_UNARY_SOFTPLUS:
+ op_type = "softplus-f32";
+ break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);