return false;
}
- // TODO: add support for non-contigiuos tensors
- if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
- return false;
- }
-
return true;
}
return true;
}
+// Capability check for offloading GGML_OP_SUM_ROWS to the Hexagon session:
+// both tensor types must be supported by the HTP kernels and both tensors
+// must be contiguous.
+static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+
+    if (!hex_supported_src0_type(src0->type) || !hex_supported_dst_type(op->type)) {
+        return false;
+    }
+
+    // TODO: add support for non-contiguous tensors
+    return ggml_is_contiguous(src0) && ggml_is_contiguous(op);
+}
+
static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
return true;
}
+// Capability check for offloading GGML_OP_ARGSORT: F32 values in, I32 indices
+// out, and the row length is capped so one row of values plus one row of
+// indices fits in the per-thread VTCM scratchpad.
+// NOTE(review): the DSP kernel addresses rows as r * nb[1]; this check does
+// not require contiguity — confirm non-contiguous tensors are rejected upstream.
+static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+    const struct ggml_tensor * values = op->src[0]; // values
+
+    if (values->type != GGML_TYPE_F32 || op->type != GGML_TYPE_I32) {
+        return false;
+    }
+
+    // reject tensors with huge rows for now
+    return values->ne[0] <= (16*1024);
+}
+
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0];
case GGML_OP_SUB:
req->op = HTP_OP_SUB;
break;
+ case GGML_OP_DIV:
+ req->op = HTP_OP_DIV;
+ break;
default:
GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
break;
return n_bufs;
}
+// Build an HTP request for GGML_OP_ARGSORT: forwards op_params (sort order)
+// and registers src0 (values, CPU->DSP) and dst (indices, DSP->CPU).
+// Returns the number of dspqueue buffers initialized.
+static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    req->op = HTP_OP_ARGSORT;
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) {
supported = true;
break;
+ case GGML_OP_SQR:
+ req->op = HTP_OP_SQR;
+ supported = true;
+ break;
+
+ case GGML_OP_SQRT:
+ req->op = HTP_OP_SQRT;
+ supported = true;
+ break;
+
case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
req->op = HTP_OP_UNARY_SILU;
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
req->op = HTP_OP_GLU_SWIGLU_OAI;
supported = true;
+ } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
+ req->op = HTP_OP_GLU_GEGLU;
+ supported = true;
}
break;
return n_bufs;
}
+// Build an HTP request for GGML_OP_SUM_ROWS: forwards op_params and registers
+// src0 (CPU->DSP) and dst (DSP->CPU). Returns the number of buffers initialized.
+static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+    memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+    req->op = HTP_OP_SUM_ROWS;
+
+    size_t n_bufs = 0;
+    n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+    n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+    return n_bufs;
+}
+
static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_ROPE;
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
+ case GGML_OP_DIV:
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
break;
case GGML_OP_ADD_ID:
case GGML_OP_SCALE:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
+ break;
+ case GGML_OP_SUM_ROWS:
+ ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
+ break;
case GGML_OP_UNARY:
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
break;
case GGML_OP_GLU:
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
- (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
+ (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
+ (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
}
break;
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
break;
+ case GGML_OP_ARGSORT:
+ ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
+ break;
+
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
+ case GGML_OP_DIV:
supp = ggml_hexagon_supported_binary(sess, op);
break;
supp = ggml_hexagon_supported_unary(sess, op);
break;
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ supp = ggml_hexagon_supported_unary(sess, op);
+ break;
+
+ case GGML_OP_SUM_ROWS:
+ supp = ggml_hexagon_supported_sum_rows(sess, op);
+ break;
+
case GGML_OP_SOFT_MAX:
supp = ggml_hexagon_supported_softmax(sess, op);
break;
case GGML_OP_GLU:
{
const auto glu_op = ggml_get_glu_op(op);
- if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
+ if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
supp = ggml_hexagon_supported_activations(sess, op);
}
break;
supp = ggml_hexagon_supported_cpy(sess, op);
break;
+ case GGML_OP_ARGSORT:
+ supp = ggml_hexagon_supported_argsort(sess, op);
+ break;
+
default:
break;
}
include_directories(
${HEXAGON_SDK_ROOT}/incs
${HEXAGON_SDK_ROOT}/incs/stddef
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../../include
${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}
matmul-ops.c
binary-ops.c
unary-ops.c
+ sum-rows-ops.c
softmax-ops.c
act-ops.c
rope-ops.c
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
+ argsort-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
// gelu = x * sigmoid(1.702 * x) // current implementation
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
- hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
// silu = x * sigmoid(x)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
- hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
+static const float GELU_COEF_A = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+// GEGLU activation, one thread's share of rows:
+//   dst = gelu(x) * g, where x = src0 and g = src1, or the two halves of src0
+//   when src1 is absent (half selection controlled by op_params[1] "swapped").
+// gelu() uses the tanh approximation. Rows stream DDR -> VTCM -> DDR through
+// a double-buffered (ping-pong) DMA pipeline, computing on whole row blocks.
+static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
+                                     const struct htp_tensor * src1,
+                                     struct htp_tensor * dst,
+                                     const int32_t * op_params,
+                                     struct htp_spad * src0_spad,
+                                     struct htp_spad * src1_spad,
+                                     struct htp_spad * dst_spad,
+                                     uint32_t nth,
+                                     uint32_t ith,
+                                     uint32_t src0_nrows_per_thread,
+                                     dma_queue * dma_queue) {
+    htp_act_preamble3;
+
+    size_t src0_row_size = nb01;
+    size_t src1_row_size = nb11;
+    size_t dst_row_size = nb1;
+
+    uint64_t t1, t2;
+    t1 = HAP_perf_get_qtimer_count();
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return;
+    }
+
+    const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+    uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+    // src1 absent (ne[0] == 0): gate and value are the two halves of src0
+    const bool src1_valid = src1->ne[0];
+    const int nc = (src1_valid) ? ne00 : ne00 / 2;
+    if (!src1_valid) {
+        const int32_t swapped = op_params[1];
+        data_src1 = data_src0;
+        src1_row_size = src0_row_size;
+
+        const size_t nc_in_bytes = nc * SIZEOF_FP32;
+        data_src0 += swapped ? nc_in_bytes : 0;
+        data_src1 += swapped ? 0 : nc_in_bytes;
+    }
+
+    // Row sizes rounded up to the HVX vector length for aligned VTCM access
+    const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+    const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+    const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+
+    uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
+    uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
+    uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+
+    // Split each thread's scratchpad into two halves for ping-pong double buffering
+    size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
+    size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
+    size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+
+    const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+    if (BLOCK == 0) {
+        FARF(ERROR,
+             "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+             src0_spad->size_per_thread, src0_row_size_aligned);
+        return;
+    }
+
+    // Prime the pipeline: queue up to two blocks ahead of the compute loop.
+    // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
+    for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
+        dma_queue_push_vtcm_to_ddr(dma_queue,
+            dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+            dst_row_size, dst_row_size_aligned, 0);
+
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
+            src0_row_size_aligned, src0_row_size, block_size);
+        dma_queue_push_ddr_to_vtcm(dma_queue,
+            dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
+            src1_row_size_aligned, src1_row_size, block_size);
+    }
+
+    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+        const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
+        // Pops must mirror the push order above (dst, src0, src1).
+        // NOTE(review): these locals shadow the htp_spad* parameters of the
+        // same names; the parameters are not used past this point.
+        float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
+        float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+        float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+        for (uint32_t ib = 0; ib < block_size; ib++) {
+            const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
+            const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
+            uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
+
+            // geglu tanh implementation
+            // geglu(x, g) = gelu(x) * g
+            // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = res * SQRT_2_OVER_PI
+            hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
+            hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
+            hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
+            hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res * 0.5f
+            hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
+        }
+
+        dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
+                                   dst_row_size_aligned, block_size);
+
+        // prefetch N+2 loop iteration if any
+        const uint32_t pref_block = (ir + BLOCK * 2);
+        if (pref_block < src0_end_row) {
+            const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
+                                       src0_row_size_aligned, src0_row_size, pref_block_size);
+            dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
+                                       src1_row_size_aligned, src1_row_size, pref_block_size);
+        }
+    }
+
+    dma_queue_flush(dma_queue);
+
+    t2 = HAP_perf_get_qtimer_count();
+
+    FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+         ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
+         (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
+// Worker-pool trampoline: unpack the ops context and run the per-thread
+// GEGLU kernel (n = thread count, i = this thread's index).
+static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_ops_context * octx = (struct htp_ops_context *) data;
+    glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
+                             &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+}
+
static int execute_op_activations_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
act_op_func = unary_gelu_f32;
op_type = "gelu-f32";
break;
+
+ case HTP_OP_GLU_GEGLU:
+ act_op_func = glu_geglu_f32;
+ op_type = "geglu-f32";
+ break;
default:
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
return HTP_STATUS_NO_SUPPORT;
--- /dev/null
+#include <string.h>
+#include <stdlib.h>
+#include <math.h>
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "ggml.h"
+
+#include "hvx-utils.h"
+#include "hex-dma.h"
+
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+// Per-dispatch context shared by all argsort worker threads.
+struct htp_argsort_context {
+    struct htp_ops_context * octx;             // op tensors, scratchpads, params
+    uint32_t nrows_per_thread;                 // rows each worker thread processes
+};
+
+// Returns true iff x[k] > y[k] in every one of the 32 f32 lanes of the
+// 128-byte HVX vectors: each lane where the compare holds contributes 1 to
+// an i32 reduction, so the sum equals 32 only when all lanes match.
+// Lanes holding NaN compare false, so any NaN yields false (scalar fallback
+// in the callers then handles those elements).
+static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
+{
+    const HVX_Vector one = Q6_V_vsplat_R(1);
+    const HVX_Vector zero = Q6_V_vzero();
+
+    HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
+    HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
+    HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
+    return hvx_vec_get_i32(sum) == 32;
+}
+
+// Sorts values and mirrors swaps to indices.
+static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
+ if (left >= right) return;
+
+ int pivot_idx = (left + right) / 2;
+ float pivot = values[pivot_idx];
+ int i = left;
+ int j = right;
+
+ HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
+ while (i <= j) {
+ // Vectorized scan for i
+ while (i <= j) {
+ // Check if we have at least one full vector
+ if (i + 32 <= j) {
+ HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+ if (all_greater_f32(pivot_vec, vals_vec)) {
+ // If all elements are < pivot, we can skip this whole block
+ i += 32;
+ continue;
+ }
+ }
+
+ // Scalar fallback / cleanup
+ if (values[i] < pivot) {
+ i++;
+ } else {
+ break;
+ }
+ }
+
+ // Vectorized scan for j
+ while (i <= j) {
+ if (j - 32 >= i) {
+ // Load 32 elements ending at j.
+ // Since we want `values[j] > pivot`, let's load from j-31 to j.
+ HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+ if (all_greater_f32(vals_vec, pivot_vec)) {
+ j -= 32;
+ continue;
+ }
+ }
+
+ if (values[j] > pivot) {
+ j--;
+ } else {
+ break;
+ }
+ }
+
+ if (i <= j) {
+ float tmp_val = values[i];
+ values[i] = values[j];
+ values[j] = tmp_val;
+
+ int32_t tmp_idx = indices[i];
+ indices[i] = indices[j];
+ indices[j] = tmp_idx;
+ i++;
+ j--;
+ }
+ }
+
+ if (left < j) quicksort_values_indices_asc(values, indices, left, j);
+ if (i < right) quicksort_values_indices_asc(values, indices, i, right);
+}
+
+// Descending counterpart of quicksort_values_indices_asc: identical Hoare
+// partition with HVX-accelerated 32-element block skipping, but with the
+// comparison directions reversed. Swaps are mirrored into indices[] so it
+// yields the descending argsort permutation.
+// NOTE(review): same O(n) worst-case recursion depth caveat as the
+// ascending variant.
+static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
+    if (left >= right) return;
+
+    int pivot_idx = (left + right) / 2;
+    float pivot = values[pivot_idx];
+    int i = left;
+    int j = right;
+
+    HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
+
+    while (i <= j) {
+        // Vectorized scan for i (values[i] > pivot)
+        while (i <= j) {
+            if (i + 32 <= j) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
+                if (all_greater_f32(vals_vec, pivot_vec)) {
+                    i += 32;
+                    continue;
+                }
+            }
+
+            if (values[i] > pivot) {
+                i++;
+            } else {
+                break;
+            }
+        }
+
+        // Vectorized scan for j (values[j] < pivot)
+        while (i <= j) {
+            if (j - 32 >= i) {
+                HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
+                if (all_greater_f32(pivot_vec, vals_vec)) {
+                    j -= 32;
+                    continue;
+                }
+            }
+
+            if (values[j] < pivot) {
+                j--;
+            } else {
+                break;
+            }
+        }
+
+        // Swap the out-of-place pair and mirror the swap into indices[]
+        if (i <= j) {
+            float tmp_val = values[i];
+            values[i] = values[j];
+            values[j] = tmp_val;
+
+            int32_t tmp_idx = indices[i];
+            indices[i] = indices[j];
+            indices[j] = tmp_idx;
+            i++;
+            j--;
+        }
+    }
+
+    if (left < j) quicksort_values_indices_desc(values, indices, left, j);
+    if (i < right) quicksort_values_indices_desc(values, indices, i, right);
+}
+
+// Per-thread argsort worker (worker-pool entry point; i = thread index).
+// For each assigned row: copy values into VTCM scratch, build the identity
+// index row, quicksort values while mirroring swaps into indices, then write
+// the indices back to DDR.
+static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
+    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
+    struct htp_ops_context * octx = actx->octx;
+
+    // Unpack context
+    const struct htp_tensor * src0 = &octx->src0;
+    const struct htp_tensor * dst = &octx->dst;
+
+    // Scratchpad memory
+    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
+
+    // Dimensions
+    uint32_t ne00 = src0->ne[0];
+    uint32_t ne01 = src0->ne[1];
+    uint32_t ne02 = src0->ne[2];
+    uint32_t ne03 = src0->ne[3];
+
+    uint32_t nb01 = src0->nb[1];
+    //uint32_t nb02 = src0->nb[2];
+    //uint32_t nb03 = src0->nb[3];
+
+    uint32_t nb1 = dst->nb[1];
+    //uint32_t nb2 = dst->nb[2];
+    //uint32_t nb3 = dst->nb[3];
+
+    // Sort order
+    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
+
+    // Rows to process
+    uint32_t total_rows = ne01 * ne02 * ne03;
+    uint32_t rows_per_thread = actx->nrows_per_thread;
+    uint32_t start_row = rows_per_thread * i;
+    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
+
+    // Scratchpad layout:
+    // We need space for one row of float data (values) and one row of int32 indices.
+    // values: ne00 * sizeof(float)
+    // indices: ne00 * sizeof(int32_t)
+    // Padded to 128 bytes.
+
+    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+    float * values_buf = (float *) spad;
+    int32_t * indices_buf = (int32_t *) (spad + values_size);
+
+    for (uint32_t r = start_row; r < end_row; r++) {
+        // NOTE(review): rows are addressed with the flat index r and nb[1]
+        // only (nb[2]/nb[3] unused) — assumes row-contiguous layout across
+        // the higher dims; confirm this holds for all accepted tensors.
+        uint32_t src_offset = r * nb01;
+        uint32_t dst_offset = r * nb1;
+
+        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
+        uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
+
+        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
+        hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
+
+        // Initialize indices
+        for (uint32_t j = 0; j < ne00; j++) {
+            indices_buf[j] = j;
+        }
+
+        // Sort values and mirror swaps to indices
+        if (order == GGML_SORT_ORDER_ASC) {
+            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
+        } else {
+            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
+        }
+
+        // Copy indices back to DDR
+        // (f32 copy helper reused for i32 data: both are 4-byte elements)
+        hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
+    }
+}
+
+// Top-level ARGSORT op entry: validates the input type, carves the VTCM
+// scratchpad into per-thread (values + indices) row buffers, and fans the
+// rows out across the worker pool. Returns an HTP_STATUS_* code.
+int op_argsort(struct htp_ops_context * octx) {
+    // Check supported types
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    // Allocate scratchpad
+    // We need 1 row of float + 1 row of int32 per thread.
+    uint32_t ne00 = octx->src0.ne[0];
+    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
+    size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
+    size_t spad_per_thread = values_size + indices_size;
+
+    // Make sure we round up to 256 for alignment requirements
+    spad_per_thread = hex_round_up(spad_per_thread, 256);
+
+    size_t total_spad_size = spad_per_thread * octx->n_threads;
+
+    if (octx->ctx->vtcm_size < total_spad_size) {
+        FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
+        return HTP_STATUS_VTCM_TOO_SMALL;
+    }
+
+    octx->src0_spad.data = octx->ctx->vtcm_base;
+    octx->src0_spad.size = total_spad_size;
+    octx->src0_spad.size_per_thread = spad_per_thread;
+
+    FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
+         octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
+         octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
+         octx->src0.data, octx->dst.data);
+
+    // Split rows evenly over at most n_threads jobs (ceiling division)
+    uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
+    uint32_t n_jobs = MIN(total_rows, octx->n_threads);
+
+    struct htp_argsort_context actx;
+    actx.octx = octx;
+    actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
+
+    // Run jobs
+    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
+
+    return HTP_STATUS_OK;
+}
#include "htp-msg.h"
#include "htp-ops.h"
-typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems);
-
-static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
-static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa };
+#ifndef MIN
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+// Context for binary operations
+// Shared by all binary-op worker threads; fastdiv/fastmodulo tables replace
+// runtime division when decomposing a flat row index into (i01, i02, i03)
+// and mapping those onto broadcast src1 coordinates.
+struct htp_binary_context {
+    struct htp_ops_context * octx;
+    struct fastdiv_values dim1_div;   // divide by ne01
+    struct fastdiv_values dim2_div;   // divide by ne02
+    struct fastdiv_values dim12_div;  // divide by ne01*ne02
+
+    struct fastdiv_values src1_dim1_div; // ne11
+    struct fastdiv_values src1_dim2_div; // ne12
+    struct fastdiv_values src1_dim3_div; // ne13
+
+    uint32_t nrows_per_thread;
+    bool split_at_ne01;  // block must not cross an ne01 boundary
+    bool split_at_ne02;  // block must not cross an ne01*ne02 plane boundary
+
+    // Precomputed values
+    uint32_t block_max;
+    size_t src0_row_size_aligned;
+    size_t src1_row_size_aligned;
+    size_t dst_row_size_aligned;
+    uint32_t src1_fetch_rows; // 1 or block_max
+    uint32_t src1_dma_stride; // 0 or stride
+};
#define htp_binary_preamble \
const struct htp_tensor * src0 = &octx->src0; \
const struct htp_tensor * src1 = &octx->src1; \
- const struct htp_tensor * src2 = &octx->src2; \
struct htp_tensor * dst = &octx->dst; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne12 = src1->ne[2]; \
const uint32_t ne13 = src1->ne[3]; \
\
- const uint32_t ne0 = dst->ne[0]; \
- const uint32_t ne1 = dst->ne[1]; \
- const uint32_t ne2 = dst->ne[2]; \
- const uint32_t ne3 = dst->ne[3]; \
- \
- const uint32_t nb00 = src0->nb[0]; \
const uint32_t nb01 = src0->nb[1]; \
const uint32_t nb02 = src0->nb[2]; \
const uint32_t nb03 = src0->nb[3]; \
\
- const uint32_t nb10 = src1->nb[0]; \
const uint32_t nb11 = src1->nb[1]; \
const uint32_t nb12 = src1->nb[2]; \
const uint32_t nb13 = src1->nb[3]; \
\
- const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
- const uint32_t nb3 = dst->nb[3]; \
- \
- const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+ const uint32_t nb3 = dst->nb[3];
-static void binary_job_f32_per_thread(struct htp_ops_context * octx,
- uint8_t * spad_data,
- uint32_t nth,
- uint32_t ith,
- enum htp_op op) {
- htp_binary_preamble;
+// Compute how many rows starting at flat row index ir can be processed as one
+// DMA/compute block: capped by rows left in [ir, end_row), by the configured
+// ne01 / plane boundaries (when the split flags are set), and by block_max.
+static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row,
+                                       uint32_t ne01, uint32_t ne02) {
+    uint32_t i03, i02, i01, rem;
+    i03 = fastdiv(ir, &bctx->dim12_div);
+    rem = ir - i03 * (ne02 * ne01);
+    i02 = fastdiv(rem, &bctx->dim1_div);
+    i01 = rem - i02 * ne01;
-    const size_t src0_row_size = nb01;
-    const size_t src1_row_size = nb11;
-    const size_t dst_row_size = nb1;
+    uint32_t rows_left = end_row - ir;
+    uint32_t block_limit = rows_left;
-    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
-    const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
+    if (bctx->split_at_ne01) {
+        block_limit = MIN(block_limit, ne01 - i01);
+    }
+    if (bctx->split_at_ne02) {
+        uint32_t rows_in_plane = (ne02 * ne01) - rem;
+        block_limit = MIN(block_limit, rows_in_plane);
+    }
-    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
-    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+    return MIN(bctx->block_max, block_limit);
+}
- // no work for this thread
- if (src0_start_row >= src0_end_row) {
- return;
+// Macro for scalar op switch
+#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
+ case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
+ case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
+ case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
+ default: break; \
}
- uint64_t t1, t2;
- t1 = HAP_perf_get_qtimer_count();
+// Macro for vector op switch (All Aligned)
+#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ }
- int is_aligned = 1;
- int opt_path = 0;
- if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
- (0 == hex_is_aligned((void *) dst->data, VLEN))) {
- is_aligned = 0;
+// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
+#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
+ default: break; \
}
- if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
- opt_path = 1;
+
+// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
+#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
+ default: break; \
}
- hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
+// 1. Scalar src1 (ne10 == 1)
+// Binary op worker for the scalar-src1 case (ne10 == 1): each dst row is
+// src0 row OP a single float taken from src1 (with broadcast over the outer
+// dims). src0 rows stream DDR -> VTCM through a double-buffered DMA pipeline;
+// results stream back VTCM -> DDR.
+static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+    struct htp_ops_context * octx = bctx->octx;
+    htp_binary_preamble;
-    uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
+    const uint32_t total_rows = ne01 * ne02 * ne03;
+    const uint32_t start_row = bctx->nrows_per_thread * ith;
+    const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+    if (start_row >= end_row) return;
+
+    // Per-thread scratchpad, split into two halves for ping-pong buffering
+    uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+    uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+    size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+    size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+    dma_queue * q = octx->ctx->dma[ith];
+    uint32_t ir_prefetch = start_row;
+    int spad_idx = 0;
+
+    // Preamble: prime the pipeline with up to two blocks
+    for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+        uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+        rem = ir_prefetch - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+        uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+        uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+        // 0-length dst transfer first: keeps the queue's pop order (dst, src0)
+        dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+        dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+        ir_prefetch += current_block_size;
+        spad_idx ^= 1;
+    }
-    const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
-    uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
+    // Main loop
+    for (uint32_t ir = start_row; ir < end_row; ) {
+        uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+
+        // Pops mirror the push order above: dst buffer first, then src0
+        uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+        uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+        uint32_t i03, i02, i01, rem;
+        i03 = fastdiv(ir, &bctx->dim12_div);
+        rem = ir - i03 * (ne02 * ne01);
+        i02 = fastdiv(rem, &bctx->dim1_div);
+        i01 = rem - i02 * ne01;
+
+        // src1 indices (broadcast/repeat)
+        uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+        uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+        uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div);
+
+        // NOTE(review): src1_ptr strides linearly within the block when
+        // ne11 > 1 — assumes a block never crosses an ne11 wrap point
+        // (presumably guaranteed by split_at_ne01; confirm).
+        uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+        uint32_t s1_stride = (ne11 == 1) ? 0 : nb11;
+
+        for (uint32_t r = 0; r < current_block_size; r++) {
+            uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+            uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+            float val = *(float *)src1_ptr;
+            src1_ptr += s1_stride;
+            COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
+        }
-    const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
+        uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+        dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
-    const uint32_t ne02_ne01 = ne02 * ne01;
+        // Prefetch the next unqueued block into the buffer just freed
+        if (ir_prefetch < end_row) {
+            uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+            uint32_t p03, p02, p01, prem;
+            p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+            prem = ir_prefetch - p03 * (ne02 * ne01);
+            p02 = fastdiv(prem, &bctx->dim1_div);
+            p01 = prem - p02 * ne01;
+            uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
-    for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
-        const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
-        const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
-        const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+            dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+            ir_prefetch += next_block_size;
+        }
+        ir += current_block_size;
+    }
+    dma_queue_flush(q);
+}
- const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
- const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
- const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
+// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast
+static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
- const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t src1_spad_half = octx->src1_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint32_t i13 = (ne13 == 1) ? 0 : i03;
+ uint32_t i12 = (ne12 == 1) ? 0 : i02;
+ uint32_t i11 = (ne11 == 1) ? 0 : i01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
- if (ir + 1 < src0_end_row) {
- hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
- if (src1_row_size == src0_row_size) {
- hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1);
- }
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+ uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
}
- const uint32_t nr0 = ne00 / ne10;
- if (nr0 > 1) {
- if ((1 == is_aligned) && (nr0 == ne00)) {
- hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0);
- } else {
- for (uint32_t r = 0; r < nr0; r++) {
- memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
- }
- }
- func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00);
- } else {
- func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+
+ uint32_t p13 = (ne13 == 1) ? 0 : p03;
+ uint32_t p12 = (ne12 == 1) ? 0 : p02;
+ uint32_t p11 = (ne11 == 1) ? 0 : p01;
+
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
+
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);
+
+ ir_prefetch += next_block_size;
}
-
- src0_ptr += src0_row_size;
- dst_ptr += dst_row_size;
+ ir += current_block_size;
}
-
- t2 = HAP_perf_get_qtimer_count();
-
- FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
- ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
- (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+ dma_queue_flush(q);
}
-static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
- uint8_t * spad_data,
- uint32_t nth,
- uint32_t ith,
- hvx_elemwise_f32_func func_HVX) {
+// 3. Row Broadcast (ne11 == ne12 == ne13 == 1: src1 is a single row, preloaded once into VTCM)
+static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
- const size_t src0_row_size = nb01;
- const size_t src1_row_size = nb11;
- const size_t dst_row_size = nb1;
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
- const uint32_t src0_start_row = src0_nrows_per_thread * ith;
- const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
- // no work for this thread
- if (src0_start_row >= src0_end_row) {
- return;
- }
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
- uint64_t t1, t2;
- t1 = HAP_perf_get_qtimer_count();
+ void * s1_ptr = (void *) src1_spad;
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
- const uint32_t ne02_ne01 = ne02 * ne01;
- for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
- // src0 indices
- const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
- const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
- const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
- // src1 indices
- const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
- assert(i11 >= 0 && i11 < ne11);
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
- float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
- const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
- const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
- if (ir + 1 < src0_end_row) {
- hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1);
- if (src1_row_size == src0_row_size) {
- hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1);
- }
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
}
- const uint32_t nr0 = ne00 / ne10;
- if (nr0 > 1) {
- for (uint32_t r = 0; r < nr0; r++) {
- memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
- }
- func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00);
- } else {
- func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
}
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
+}
+
+// 4. Vector Complex (ne10 == ne00, non-uniform broadcast: src1 row is located per output row via fastmodulo and read directly from DDR)
+static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
}
- t2 = HAP_perf_get_qtimer_count();
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint32_t r_i01 = i01 + r;
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+
+ // Read src1 from DDR (unaligned)
+ COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
+ }
- FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
- src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
- dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
+ }
+ dma_queue_flush(q);
}
-static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
+// 5. Element Repeat (ne10 != ne00: the src1 row is tiled along the row dimension in ne10-sized chunks)
+static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
+ htp_binary_preamble;
- switch (octx->op) {
- case HTP_OP_MUL:
- case HTP_OP_ADD:
- case HTP_OP_SUB:
- binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
- break;
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
- case HTP_OP_ADD_ID:
- binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
- break;
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint32_t r_i01 = i01 + r;
+ uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div);
+ uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div);
+ uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div);
+
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+
+ // Repeat src1 row
+ for (uint32_t c = 0; c < ne00; c += ne10) {
+ uint32_t len = MIN(ne10, ne00 - c);
+                // UUU (all-unaligned) variant keeps the segment loop simple; segments need not be vector-aligned
+ COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
+ }
+ }
- default:
- FARF(ERROR, "Unknown Binary Op %u", octx->op);
- break;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
}
+ dma_queue_flush(q);
}
-static int execute_op_binary_f32(struct htp_ops_context * octx) {
- int err = HTP_STATUS_OK;
+// 6. ADD_ID (src1 gathered via src2 indices)
+static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_binary_context * bctx = (struct htp_binary_context *) data;
+ struct htp_ops_context * octx = bctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
+ const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
- worker_callback_t binary_op_func;
- const char * op_type = NULL;
-
- switch (octx->op) {
- case HTP_OP_MUL:
- binary_op_func = binary_job_dispatcher_f32;
- op_type = "mul-f32";
- break;
-
- case HTP_OP_ADD:
- binary_op_func = binary_job_dispatcher_f32;
- op_type = "add-f32";
- break;
-
- case HTP_OP_SUB:
- binary_op_func = binary_job_dispatcher_f32;
- op_type = "sub-f32";
- break;
-
- case HTP_OP_ADD_ID:
- binary_op_func = binary_job_dispatcher_f32;
- op_type = "add-id-f32";
- break;
-
- default:
- FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
- return HTP_STATUS_NO_SUPPORT;
+ const uint32_t ne00 = src0->ne[0];
+ const uint32_t ne01 = src0->ne[1];
+ const uint32_t ne02 = src0->ne[2];
+ const uint32_t ne03 = src0->ne[3];
+    const uint32_t ne11 = src1->ne[1]; // NOTE(review): unused — the removed code asserted 0 <= idx < ne11 before indexing src1; no such check exists here, confirm or restore it
+
+ const uint32_t nb01 = src0->nb[1];
+ const uint32_t nb02 = src0->nb[2];
+ const uint32_t nb03 = src0->nb[3];
+ const uint32_t nb11 = src1->nb[1]; // src1 row stride
+ const uint32_t nb1 = dst->nb[1];
+ const uint32_t nb2 = dst->nb[2];
+ const uint32_t nb3 = dst->nb[3];
+
+ const uint32_t total_rows = ne01 * ne02 * ne03;
+ const uint32_t start_row = bctx->nrows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
+ if (start_row >= end_row) return;
+
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+ size_t src0_spad_half = octx->src0_spad.size_per_thread / 2;
+ size_t dst_spad_half = octx->dst_spad.size_per_thread / 2;
+
+ dma_queue * q = octx->ctx->dma[ith];
+ uint32_t ir_prefetch = start_row;
+ int spad_idx = 0;
+
+ for (int k = 0; k < 2 && ir_prefetch < end_row; k++) {
+ uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ rem = ir_prefetch - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01;
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+
+ uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half;
+ uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
+
+ dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ ir_prefetch += current_block_size;
+ spad_idx ^= 1;
+ }
+
+ for (uint32_t ir = start_row; ir < end_row; ) {
+ uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02);
+ uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src;
+ uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst;
+
+ uint32_t i03, i02, i01, rem;
+ i03 = fastdiv(ir, &bctx->dim12_div);
+ rem = ir - i03 * (ne02 * ne01);
+ i02 = fastdiv(rem, &bctx->dim1_div);
+ i01 = rem - i02 * ne01;
+
+ for (uint32_t r = 0; r < current_block_size; r++) {
+ uint32_t r_i01 = i01 + r; // linear within block since we split at ne01
+
+ const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]);
+
+ uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11;
+ uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
+ uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
+
+ hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00);
+ }
+
+ uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+
+ if (ir_prefetch < end_row) {
+ uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
+ uint32_t p03, p02, p01, prem;
+ p03 = fastdiv(ir_prefetch, &bctx->dim12_div);
+ prem = ir_prefetch - p03 * (ne02 * ne01);
+ p02 = fastdiv(prem, &bctx->dim1_div);
+ p01 = prem - p02 * ne01;
+ uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ ir_prefetch += next_block_size;
+ }
+ ir += current_block_size;
}
+ dma_queue_flush(q);
+}
+
+static int execute_op_binary_f32(struct htp_ops_context * octx) {
+ const struct htp_tensor * src0 = &octx->src0;
+ const struct htp_tensor * src1 = &octx->src1;
+ struct htp_tensor * dst = &octx->dst;
- const int n_threads = octx->n_threads;
+ const uint32_t n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
- const size_t src0_row_size = src0->nb[1];
- const size_t src1_row_size = src1->nb[1];
- const size_t dst_row_size = dst->nb[1];
+ // Use packed row sizes for VTCM allocation
+ const size_t src0_row_size = src0->ne[0] * sizeof(float);
+ const size_t src1_row_size = src1->ne[0] * sizeof(float);
+ const size_t dst_row_size = dst->ne[0] * sizeof(float);
+
+ // Align to VLEN
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
+
+ bool is_add_id = (octx->op == HTP_OP_ADD_ID);
+ bool is_scalar = !is_add_id && (src1->ne[0] == 1);
+
+    // Determine which kernel will run so the VTCM scratchpads can be sized accordingly before dispatch
+ bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
+ (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
+ (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
+ (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
+
+ bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+ bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]);
+ bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]);
+
+ size_t spad_row_total;
+ if (is_scalar) {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+ } else if (is_row_bcast) {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+ } else if (use_vector_same) {
+ spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned);
+ } else if (is_add_id) {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly
+ } else {
+ spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+ }
- // VTCM scratchpads for all tensors
- octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
- octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
- octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
+ size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total);
+ // Adjust for static src1 in row_bcast case
+ if (is_row_bcast) {
+ size_t needed_static = src1_row_size_aligned;
+ if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL;
+ size_t avail = octx->ctx->vtcm_size - needed_static;
+ rows_per_buffer = avail / (n_threads * spad_row_total);
+ }
- size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
+ if (rows_per_buffer < 1) {
+ FARF(ERROR, "binary-f32: VTCM too small\n");
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
- FARF(HIGH,
- "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
- op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
- src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
- octx->dst_spad.size);
+ octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned;
+ octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned;
- // Make sure the reserved vtcm size is sufficient
- if (octx->ctx->vtcm_size < spad_size) {
- FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
- octx->ctx->vtcm_size, spad_size);
+ if (is_scalar || use_complex || use_repeat || is_add_id) {
+ octx->src1_spad.size_per_thread = 0;
+ } else if (is_row_bcast) {
+ octx->src1_spad.size_per_thread = 0;
+ } else {
+ octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned;
+ }
+
+ octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
+ if (is_row_bcast) {
+ octx->src1_spad.size = src1_row_size_aligned;
+ } else {
+ octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread;
+ }
+ octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
+
+ if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) {
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
- if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
+ if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ return HTP_STATUS_OK;
+ }
+
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ dma_queue * q = octx->ctx->dma[0];
+ if (is_row_bcast) {
+ dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
+ }
- octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
- octx->src0_div3 = init_fastdiv_values(src0->ne[3]);
- octx->src0_div2 = init_fastdiv_values(src0->ne[2]);
- octx->src0_div1 = init_fastdiv_values(src0->ne[1]);
+ struct htp_binary_context bctx;
+ bctx.octx = octx;
+ bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ bctx.block_max = rows_per_buffer;
+ bctx.src0_row_size_aligned = src0_row_size_aligned;
+ bctx.src1_row_size_aligned = src1_row_size_aligned;
+ bctx.dst_row_size_aligned = dst_row_size_aligned;
- octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
- octx->src1_div3 = init_fastdiv_values(src1->ne[3]);
- octx->src1_div2 = init_fastdiv_values(src1->ne[2]);
- octx->src1_div1 = init_fastdiv_values(src1->ne[1]);
+ bctx.dim1_div = init_fastdiv_values(src0->ne[1]);
+ bctx.dim2_div = init_fastdiv_values(src0->ne[2]);
+ bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]);
- worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
- }
+ bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]);
+ bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]);
+ bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]);
- return err;
-}
+ bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]);
+ bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]);
-int op_binary(struct htp_ops_context * octx) {
- int err = HTP_STATUS_OK;
+ bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]);
+ bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]);
- switch (octx->src0.type) {
- case HTP_TYPE_F32:
- err = execute_op_binary_f32(octx);
- break;
+ bctx.split_at_ne01 = (src0->ne[2] > 1) &&
+ ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1);
- default:
- err = HTP_STATUS_NO_SUPPORT;
- break;
+ bctx.split_at_ne02 = (src0->ne[3] > 1) &&
+ ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2);
+
+ // Precompute specific kernel parameters
+ if (use_vector_same) {
+ bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1];
+ bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer;
}
- return err;
+ worker_callback_t worker_func;
+ if (is_add_id) worker_func = binary_job_add_id;
+ else if (is_scalar) worker_func = binary_job_scalar;
+ else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast;
+ else if (use_vector_same) worker_func = binary_job_vector_same_shape;
+ else if (use_complex) worker_func = binary_job_vector_complex;
+ else worker_func = binary_job_element_repeat;
+
+ if (is_row_bcast) {
+ dma_queue_pop(q);
+ }
+
+ worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);
+
+ return HTP_STATUS_OK;
+}
+
+int op_binary(struct htp_ops_context * octx) {
+ if (octx->src0.type == HTP_TYPE_F32) {
+ return execute_op_binary_f32(octx);
+ }
+ return HTP_STATUS_NO_SUPPORT;
}
HTP_TYPE_COUNT
};
-// These values are manually translated over to HTP
-// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
+// These values are mirrored on the HTP side; do not reorder the first 4 (used as an index)
enum htp_op {
- HTP_OP_MUL = 0,
- HTP_OP_ADD = 1,
- HTP_OP_SUB = 2,
- HTP_OP_DIV = 3,
- HTP_OP_MUL_MAT = 4,
- HTP_OP_MUL_MAT_ID = 5,
- HTP_OP_RMS_NORM = 6,
- HTP_OP_UNARY_SILU = 7,
- HTP_OP_UNARY_GELU = 8,
- HTP_OP_GLU_SWIGLU = 9,
- HTP_OP_GLU_SWIGLU_OAI = 10,
- HTP_OP_SOFTMAX = 11,
- HTP_OP_ADD_ID = 12,
- HTP_OP_ROPE = 13,
- HTP_OP_FLASH_ATTN_EXT = 14,
- HTP_OP_SET_ROWS = 15,
- HTP_OP_SCALE = 16,
- HTP_OP_GET_ROWS = 17,
- HTP_OP_CPY = 18,
+ HTP_OP_MUL = 0,
+ HTP_OP_ADD = 1,
+ HTP_OP_SUB = 2,
+ HTP_OP_DIV = 3,
+ HTP_OP_MUL_MAT,
+ HTP_OP_MUL_MAT_ID,
+ HTP_OP_RMS_NORM,
+ HTP_OP_UNARY_SILU,
+ HTP_OP_UNARY_GELU,
+ HTP_OP_GLU_SWIGLU,
+ HTP_OP_GLU_SWIGLU_OAI,
+ HTP_OP_GLU_GEGLU,
+ HTP_OP_SOFTMAX,
+ HTP_OP_ADD_ID,
+ HTP_OP_ROPE,
+ HTP_OP_FLASH_ATTN_EXT,
+ HTP_OP_SET_ROWS,
+ HTP_OP_GET_ROWS,
+ HTP_OP_SCALE,
+ HTP_OP_CPY,
+ HTP_OP_ARGSORT,
+ HTP_OP_SQR,
+ HTP_OP_SQRT,
+ HTP_OP_SUM_ROWS,
INVALID
};
-static inline size_t htp_type_block_size(uint32_t t) {
+static inline size_t htp_t_block_size(uint32_t t) {
switch (t) {
case HTP_TYPE_F32:
return 1;
return 0;
}
-static const char * htp_type_name(uint32_t t) {
- switch (t) {
- case HTP_TYPE_F32:
- return "fp32";
- case HTP_TYPE_F16:
- return "fp16";
- case HTP_TYPE_Q4_0:
- return "q4_0";
- case HTP_TYPE_Q8_0:
- return "q8_0";
- case HTP_TYPE_MXFP4:
- return "mxfp4";
- }
- return 0;
-}
-
// Internal types
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
int op_matmul_id(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
+int op_sum_rows(struct htp_ops_context * octx);
int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
+int op_argsort(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
-// ADD variants
-
-static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src0 % 128 == 0);
- assert((unsigned long) src1 % 128 == 0);
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
-}
-
-static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src0 % 128 == 0);
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
-}
-
-static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) src0 % 128 == 0);
- assert((unsigned long) src1 % 128 == 0);
- hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
-}
-
-static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
-}
-
-// SUB variants
-
-static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src0 % 128 == 0);
- assert((unsigned long) src1 % 128 == 0);
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
-}
-
-static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src0 % 128 == 0);
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
-}
-
-static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) src0 % 128 == 0);
- assert((unsigned long) src1 % 128 == 0);
- hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
-}
-
-static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
-}
-
-// MUL variants
-
-static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src0 % 128 == 0);
- assert((unsigned long) src1 % 128 == 0);
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
-}
-
-static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src0 % 128 == 0);
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
-}
-
-static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((unsigned long) src0 % 128 == 0);
- assert((unsigned long) src1 % 128 == 0);
- hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
-}
-
-static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
-}
-
-// Dispatchers
-
-static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
- if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
- if (hex_is_aligned((void *) src1, 128)) {
- hvx_add_f32_aa(dst, src0, src1, num_elems);
- } else {
- hvx_add_f32_au(dst, src0, src1, num_elems);
- }
- } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
- hvx_add_f32_ua(dst, src0, src1, num_elems);
- } else {
- hvx_add_f32_uu(dst, src0, src1, num_elems);
- }
-}
-
-static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
- if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
- if (hex_is_aligned((void *) src1, 128)) {
- hvx_sub_f32_aa(dst, src0, src1, num_elems);
- } else {
- hvx_sub_f32_au(dst, src0, src1, num_elems);
- }
- } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
- hvx_sub_f32_ua(dst, src0, src1, num_elems);
- } else {
- hvx_sub_f32_uu(dst, src0, src1, num_elems);
- }
-}
-
-static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
- if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
- if (hex_is_aligned((void *) src1, 128)) {
- hvx_mul_f32_aa(dst, src0, src1, num_elems);
- } else {
- hvx_mul_f32_au(dst, src0, src1, num_elems);
- }
- } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
- hvx_mul_f32_ua(dst, src0, src1, num_elems);
- } else {
- hvx_mul_f32_uu(dst, src0, src1, num_elems);
- }
-}
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
+static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src0 % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src0 % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src0 % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src0 % 128 == 0); \
+ hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src1 % 128 == 0); \
+ hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+} \
+
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
+
+// Dispatcher logic
+#define HVX_BINARY_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
+ if (hex_is_aligned((void *) dst, 128)) { \
+ if (hex_is_aligned((void *) src0, 128)) { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
+ else OP_NAME##_aau(dst, src0, src1, num_elems); \
+ } else { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
+ else OP_NAME##_auu(dst, src0, src1, num_elems); \
+ } \
+ } else { \
+ if (hex_is_aligned((void *) src0, 128)) { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
+ else OP_NAME##_uau(dst, src0, src1, num_elems); \
+ } else { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
+ else OP_NAME##_uuu(dst, src0, src1, num_elems); \
+ } \
+ } \
+}
+
+HVX_BINARY_DISPATCHER(hvx_add_f32)
+HVX_BINARY_DISPATCHER(hvx_sub_f32)
+HVX_BINARY_DISPATCHER(hvx_mul_f32)
// Mul-Mul Optimized
-
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
}
}
+//
+// Square
+//
+
+#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t elem_size = sizeof(float); \
+ const uint32_t epv = 128 / elem_size; \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+ vec_store((void *) &vdst[i], nloe * elem_size, v); \
+ } \
+ } while(0)
+
+static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    // src may be unaligned in this variant: load through HVX_UVector
+    hvx_sqr_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// Elementwise square of n f32 values: dispatch to the variant matching the
+// 128-byte alignment of dst/src (a = aligned, u = unaligned).
+static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
+    if (hex_is_aligned((void *) dst, 128)) {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_aa(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_au(dst, src, num_elems);
+        }
+    } else {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_ua(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_uu(dst, src, num_elems);
+        }
+    }
+}
+
+
#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
+#undef DEFINE_HVX_BINARY_OP_VARIANTS
+#undef HVX_BINARY_DISPATCHER
#endif // HVX_ARITH_H
return x;
}
+// Extract the first 32-bit lane of an HVX vector as an int32_t
+// (stores 4 bytes through a 128-byte-aligned stack slot).
+static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
+    int32_t __attribute__((aligned(128))) x;
+    hvx_vec_store_a(&x, 4, v);
+    return x;
+}
+
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
// abs by clearing the fp16 sign bit
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
- const HVX_Vector zero = Q6_V_vsplat_R(0); \
- \
const uint32_t elem_size = sizeof(__fp16); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
--- /dev/null
+#ifndef HVX_DIV_H
+#define HVX_DIV_H
+
+#include <HAP_farf.h>
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "hvx-base.h"
+#include "hex-utils.h"
+#include "hvx-inverse.h"
+#include "hvx-arith.h"
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src0_type * restrict vsrc0 = (src0_type *) src0; \
+ src1_type * restrict vsrc1 = (src1_type *) src1; \
+ \
+ const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+ \
+ const uint32_t nvec = n / VLEN_FP32; \
+ const uint32_t nloe = n % VLEN_FP32; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+ HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
+ vdst[i] = res; \
+ } \
+ if (nloe) { \
+ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
+ HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
+ } \
+ } while(0)
+
+// 3-letter suffix variants
+static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ assert((uintptr_t) src0 % 128 == 0);
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ assert((uintptr_t) src0 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) dst % 128 == 0);
+ hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) src0 % 128 == 0);
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) src0 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ assert((uintptr_t) src1 % 128 == 0);
+ hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
+ hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// Elementwise f32 division dst = src0 / src1 (computed as src0 * 1/src1 with a
+// NaN/Inf guard): dispatch on the 128-byte alignment of dst/src0/src1.
+static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
+    if (hex_is_aligned((void *) dst, 128)) {
+        if (hex_is_aligned((void *) src0, 128)) {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
+            else hvx_div_f32_aau(dst, src0, src1, num_elems);
+        } else {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
+            else hvx_div_f32_auu(dst, src0, src1, num_elems);
+        }
+    } else {
+        if (hex_is_aligned((void *) src0, 128)) {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
+            else hvx_div_f32_uau(dst, src0, src1, num_elems);
+        } else {
+            if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
+            else hvx_div_f32_uuu(dst, src0, src1, num_elems);
+        }
+    }
+}
+
+#undef HVX_OP_MUL
+
+#endif // HVX_DIV_H
} \
} while(0)
+#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t epv = 128 / sizeof(float); \
+ const uint32_t nvec = n / epv; \
+ const uint32_t nloe = n % epv; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
+ } \
+ if (nloe) { \
+ HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
+ vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
+ } \
+ } while(0)
+
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
+// Elementwise tanh over n f32 values; both pointers must be 128-byte aligned.
+static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
#endif /* HVX_SIGMOID_H */
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
//Algorithm :
// x2 = input*0.5
// y = * (long *) &input
- // y = 0x5f3759df - (y>>2)
+ // y = 0x5f3759df - (y>>1)
// y = y*(threehalfs - x2*y*y)
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
return Q6_Vsf_equals_Vqf32(temp);
}
+// Compute sqrt(x) as x*inv_sqrt(x)
+#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const uint32_t nvec = n / VLEN_FP32; \
+ const uint32_t nloe = n % VLEN_FP32; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
+ HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
+ vdst[i] = sqrt_res; \
+ } \
+ if (nloe) { \
+ HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
+ HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
+ } \
+ } while(0)
+
+static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ assert((unsigned long) src % 128 == 0);
+ hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) dst % 128 == 0);
+ hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ assert((unsigned long) src % 128 == 0);
+ hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
+ if ((unsigned long) dst % 128 == 0) {
+ if ((unsigned long) src % 128 == 0) {
+ hvx_sqrt_f32_aa(dst, src, num_elems);
+ } else {
+ hvx_sqrt_f32_au(dst, src, num_elems);
+ }
+ } else {
+ if ((unsigned long) src % 128 == 0) {
+ hvx_sqrt_f32_ua(dst, src, num_elems);
+ } else {
+ hvx_sqrt_f32_uu(dst, src, num_elems);
+ }
+ }
+}
+
#endif /* HVX_SQRT_H */
#include "hvx-sigmoid.h"
#include "hvx-sqrt.h"
#include "hvx-arith.h"
+#include "hvx-div.h"
#include "hvx-base.h"
#endif /* HVX_UTILS_H */
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
+// Handle an ARGSORT request from the CPU: bufs[0] = input values (f32),
+// bufs[1] = output indices (i32). Runs op_argsort while holding VTCM and
+// replies with the output buffer flagged for flush/invalidate.
+static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx   = ctx;
+    octx.src0  = req->src0;
+    octx.dst   = req->dst;
+    octx.flags = req->flags;
+    octx.op    = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_argsort(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
+// Handle a SUM_ROWS request from the CPU: bufs[0] = input (f32 rows),
+// bufs[1] = output (one f32 per row). Runs op_sum_rows while holding VTCM
+// and replies with the output buffer flagged for flush/invalidate.
+static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+    // only one response buffer is used (matches proc_argsort_req / proc_cpy_req)
+    struct dspqueue_buffer rsp_bufs[1];
+
+    // We had written to the output buffer, we'd also need to flush it
+    rsp_bufs[0].fd     = bufs[1].fd;
+    rsp_bufs[0].ptr    = bufs[1].ptr;
+    rsp_bufs[0].offset = bufs[1].offset;
+    rsp_bufs[0].size   = bufs[1].size;
+    rsp_bufs[0].flags  = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER |         // Flush HTP
+                          DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    // Setup Op context
+    struct htp_ops_context octx = { 0 };
+    octx.ctx   = ctx;
+    octx.src0  = req->src0;
+    octx.dst   = req->dst;
+    octx.flags = req->flags;
+    octx.op    = req->op;
+
+    memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+    // Update data pointers
+    octx.src0.data = (uint32_t) bufs[0].ptr;
+    octx.dst.data  = (uint32_t) bufs[1].ptr;
+    octx.n_threads = ctx->n_threads;
+
+    struct profile_data prof;
+    profile_start(&prof);
+
+    uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+    if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+        rsp_status = op_sum_rows(&octx);
+        vtcm_release(ctx);
+    }
+
+    profile_stop(&prof);
+    send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+
static void proc_activations_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
+ case HTP_OP_DIV:
if (n_bufs != 3) {
FARF(ERROR, "Bad binary-req buffer list");
continue;
proc_unary_req(ctx, &req, bufs);
break;
+ case HTP_OP_SQR:
+ case HTP_OP_SQRT:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad unary-req buffer list");
+ continue;
+ }
+
+ proc_unary_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_SUM_ROWS:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad unary-req buffer list");
+ continue;
+ }
+
+ proc_sum_rows_req(ctx, &req, bufs);
+ break;
+
case HTP_OP_UNARY_SILU:
case HTP_OP_UNARY_GELU:
if (n_bufs != 2) {
case HTP_OP_GLU_SWIGLU:
case HTP_OP_GLU_SWIGLU_OAI:
case HTP_OP_SOFTMAX:
+ case HTP_OP_GLU_GEGLU:
if ((n_bufs != 2) && (n_bufs != 3)) {
FARF(ERROR, "Bad act-req buffer list");
continue;
proc_cpy_req(ctx, &req, bufs);
break;
+ case HTP_OP_ARGSORT:
+ if (n_bufs != 2) {
+ FARF(ERROR, "Bad argsort-req buffer list");
+ continue;
+ }
+ proc_argsort_req(ctx, &req, bufs);
+ break;
+
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;
--- /dev/null
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#include <HAP_farf.h>
+#include <HAP_perf.h>
+
+#include <string.h>
+#include <math.h>
+
+#include "hex-dma.h"
+#include "hvx-utils.h"
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+
+
+#define sum_rows_preamble \
+ struct htp_tensor *src0 = &octx->src0;\
+ struct htp_tensor *dst = &octx->dst; \
+ \
+ const uint32_t ne00 = src0->ne[0]; \
+ const uint32_t ne01 = src0->ne[1]; \
+ const uint32_t ne02 = src0->ne[2]; \
+ const uint32_t ne03 = src0->ne[3]; \
+ \
+ const uint32_t nb00 = src0->nb[0]; \
+ const uint32_t nb01 = src0->nb[1]; \
+ const uint32_t nb02 = src0->nb[2]; \
+ const uint32_t nb03 = src0->nb[3]; \
+ \
+ const uint32_t ne0 = dst->ne[0]; \
+ const uint32_t ne1 = dst->ne[1]; \
+ const uint32_t ne2 = dst->ne[2]; \
+ const uint32_t ne3 = dst->ne[3]; \
+ \
+ const uint32_t nb0 = dst->nb[0]; \
+ const uint32_t nb1 = dst->nb[1]; \
+ const uint32_t nb2 = dst->nb[2]; \
+ const uint32_t nb3 = dst->nb[3]; \
+
+// Per-thread worker: reduce each assigned src0 row (ne00 f32 elements) to a
+// single f32 written to the corresponding dst row. Row ranges are assigned by
+// thread index; assumes contiguous rows (enforced by the CPU-side support check).
+static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+    sum_rows_preamble;
+
+    const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+    const size_t src0_row_size = nb01;
+    const size_t dst_row_size = nb1;
+
+    const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+
+    const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+    const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+    // no work for this thread
+    if (src0_start_row >= src0_end_row) {
+        return HTP_STATUS_OK;
+    }
+
+    // rows actually assigned to this thread: the last thread may get fewer
+    // than src0_nrows_per_thread when the row count does not divide evenly
+    const uint32_t src0_nrows_th = src0_end_row - src0_start_row;
+
+    // fast path needs VLEN-aligned data and a VLEN-multiple row stride
+    // (hex_is_aligned() returns non-zero when aligned, as in the hvx dispatchers)
+    int opt_path = 0;
+    if (hex_is_aligned((void *) src0->data, VLEN) && !(nb01 & (VLEN - 1))) {
+        opt_path = 1;
+    }
+
+    const uint8_t * restrict data_src = (const uint8_t *) src0->data;
+    uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+    const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
+    float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
+
+    for (uint32_t ir = 0; ir < src0_nrows_th; ir++) {
+        const float * restrict src_local = src_th + (ir * ne00);
+
+        // prefetch the next row while reducing the current one
+        if (ir + 1 < src0_nrows_th) {
+            hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
+        }
+
+        if (1 == opt_path) {
+            dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
+        } else {
+            dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
+        }
+    }
+
+    return HTP_STATUS_OK;
+}
+
+static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
+ sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
+}
+
+// Entry point for SUM_ROWS: validates type/flags, splits the row range across
+// worker threads, and runs sum_rows_work_f32 on the worker pool.
+int op_sum_rows(struct htp_ops_context * octx) {
+    sum_rows_preamble;
+
+    if (octx->src0.type != HTP_TYPE_F32) {
+        return HTP_STATUS_NO_SUPPORT;
+    }
+
+    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+        return HTP_STATUS_OK;
+    }
+
+    const int n_threads = octx->n_threads;
+    const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+    // empty tensor: nothing to do (also avoids div-by-zero when n_jobs == 0)
+    if (src0_nrows == 0) {
+        return HTP_STATUS_OK;
+    }
+
+    uint32_t n_jobs = MIN(n_threads, src0_nrows);
+    octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
+
+    return HTP_STATUS_OK;
+}
+
+
}
}
+static void sqr_htp_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params,
+ int opt_path) {
+
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
+ const float * restrict src_local = src + (ir * row_elems);
+ float * restrict dst_local = dst + (ir * row_elems);
+
+ if (ir + 1 < num_rows) {
+ hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+ }
+
+ if (1 == opt_path) {
+ hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ } else {
+ hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ }
+ }
+}
+
+static void sqrt_htp_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params,
+ int opt_path) {
+
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
+ const float * restrict src_local = src + (ir * row_elems);
+ float * restrict dst_local = dst + (ir * row_elems);
+
+ if (ir + 1 < num_rows) {
+ hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
+ }
+
+ if (1 == opt_path) {
+ hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ } else {
+ hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
+ }
+ }
+}
+
static void unary_job_f32_per_thread(const struct htp_tensor * src,
struct htp_tensor * dst,
uint8_t * spad,
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
+ case HTP_OP_SQR:
+ sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+ break;
+ case HTP_OP_SQRT:
+ sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+ break;
default:
break;
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
break;
+ case HTP_OP_SQR:
+ unary_op_func = unary_job_dispatcher_f32;
+ op_type = "sqr-f32";
+ break;
+ case HTP_OP_SQRT:
+ unary_op_func = unary_job_dispatcher_f32;
+ op_type = "sqrt-f32";
+ break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);