| [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE |
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
-| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
+| [Hexagon [In Progress]](docs/backend/snapdragon/README.md) | Snapdragon |
| [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR |
## Obtaining and quantizing models
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
- if (src0->type != GGML_TYPE_F32) {
- return false;
+ if (src0->type == GGML_TYPE_F32) {
+ if (src1->type != GGML_TYPE_F32) {
+ return false;
+ }
+ if (dst->type != GGML_TYPE_F32) {
+ return false;
+ }
}
- if (src1->type != GGML_TYPE_F32) {
- return false;
+ else if (src0->type == GGML_TYPE_F16) {
+ if (src1->type != GGML_TYPE_F16) {
+ return false;
+ }
+ if (dst->type != GGML_TYPE_F16) {
+ return false;
+ }
}
- if (dst->type != GGML_TYPE_F32) {
+ else {
return false;
}
+
if (!ggml_are_same_shape(src0, dst)) {
return false;
}
return HTP_STATUS_NO_SUPPORT;
}
- const uint32_t n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
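+ // Never launch more worker jobs than there are rows to split across them.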
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
size_t src0_row_size = src0->nb[1];
size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used
return HTP_STATUS_OK;
}
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
-
// Prepare context
struct htp_act_context actx;
actx.octx = octx;
- actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
actx.src0_row_size = src0_row_size;
actx.src1_row_size = src1_row_size;
actx.data_src1 = data_src1;
actx.data_dst = (uint8_t *) dst->data;
- worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads);
return HTP_STATUS_OK;
}
return HTP_STATUS_NO_SUPPORT;
}
+ const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
+ const uint32_t n_threads = MIN(total_rows, octx->n_threads);
+
// Allocate scratchpad
// We need 1 row of float + 1 row of int32 per thread.
uint32_t ne00 = octx->src0.ne[0];
// Make sure we round up to 256 for alignment requirements
spad_per_thread = hex_round_up(spad_per_thread, 256);
- size_t total_spad_size = spad_per_thread * octx->n_threads;
+ size_t total_spad_size = spad_per_thread * n_threads;
if (octx->ctx->vtcm_size < total_spad_size) {
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
octx->src0.data, octx->dst.data);
- uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
- uint32_t n_jobs = MIN(total_rows, octx->n_threads);
-
struct htp_argsort_context actx;
actx.octx = octx;
- actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
+ actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads;
// Run jobs
- worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads);
return HTP_STATUS_OK;
}
}
// Macro for scalar op switch
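+// VAL is passed as a pointer and dereferenced as float or _Float16 to match TYPE,
+// so a single call site can drive both element types.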
-#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \
- switch (octx->op) { \
- case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \
- case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \
- case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \
- case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \
- default: break; \
+#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \
+ if(TYPE == HTP_TYPE_F32) { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+ case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+ case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \
+ case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \
+ default: break; \
+ } \
+ } \
+ else { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+ case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+ case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+ case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \
+ default: break; \
+ } \
}
// Macro for vector op switch (All Aligned)
-#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \
- switch (octx->op) { \
- case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
- case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
- case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
- case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
- default: break; \
+#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \
+ if(TYPE == HTP_TYPE_F32) { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ } \
+ } \
+ else { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ } \
}
// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned)
-#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \
- switch (octx->op) { \
- case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
- case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
- case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
- case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
- default: break; \
+#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \
+ if(TYPE == HTP_TYPE_F32) { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ } \
+ } \
+ else { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ } \
}
// Macro for vector op switch (All Unaligned - generic loop used in element repeat)
-#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \
- switch (octx->op) { \
- case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
- case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
- case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
- case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
- default: break; \
+#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \
+ if(TYPE == HTP_TYPE_F32) { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ } \
+ } \
+ else { \
+ switch (octx->op) { \
+ case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \
+ case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \
+ default: break; \
+ } \
}
// 1. Scalar src1 (ne10 == 1)
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
+ const uint32_t src0_type = octx->src0.type;
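+ // Bytes per packed row for DMA transfers; depends on whether the tensors are F32 or F16.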
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
- dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
for (uint32_t r = 0; r < current_block_size; r++) {
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
- float val = *(float *)src1_ptr;
+ COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00);
src1_ptr += s1_stride;
- COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00);
}
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
- dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
- dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
+ const uint32_t src0_type = octx->src0.type;
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
- dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
- dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
+ dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned;
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
- COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
}
uint32_t i03, i02, i01, rem;
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
- dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11;
- dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
- dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
+ dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
+ const uint32_t src0_type = octx->src0.type;
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
- dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned;
uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
- COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00);
+ COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00);
}
uint32_t i03, i02, i01, rem;
i02 = fastdiv(rem, &bctx->dim1_div);
i01 = rem - i02 * ne01;
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
- dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
- dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
+ const uint32_t src0_type = octx->src0.type;
+ const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16);
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
- dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned;
// Read src1 from DDR (unaligned)
- COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00);
+ COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00);
}
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
- dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
- dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
struct htp_ops_context * octx = bctx->octx;
htp_binary_preamble;
+ const uint32_t src0_type = octx->src0.type;
+ const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
+ const uint32_t row_size_bytes = ne00 * elem_size_bytes;
const uint32_t total_rows = ne01 * ne02 * ne03;
const uint32_t start_row = bctx->nrows_per_thread * ith;
const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows);
uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half;
dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0);
- dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size);
ir_prefetch += current_block_size;
spad_idx ^= 1;
}
for (uint32_t c = 0; c < ne00; c += ne10) {
uint32_t len = MIN(ne10, ne00 - c);
// Use UUU for speed and simplicity
- COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len);
+ COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len);
}
}
uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1;
- dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size);
+ dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size);
if (ir_prefetch < end_row) {
uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02);
p02 = fastdiv(prem, &bctx->dim1_div);
p01 = prem - p02 * ne01;
uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01;
- dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size);
+ dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size);
ir_prefetch += next_block_size;
}
ir += current_block_size;
dma_queue_flush(q);
}
-static int execute_op_binary_f32(struct htp_ops_context * octx) {
+static int execute_op_binary(struct htp_ops_context * octx) {
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
- const uint32_t n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
// Use packed row sizes for VTCM allocation
- const size_t src0_row_size = src0->ne[0] * sizeof(float);
- const size_t src1_row_size = src1->ne[0] * sizeof(float);
- const size_t dst_row_size = dst->ne[0] * sizeof(float);
+ const uint32_t src0_type = octx->src0.type;
+ const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16);
+ const size_t src0_row_size = src0->ne[0] * elem_size;
+ const size_t src1_row_size = src1->ne[0] * elem_size;
+ const size_t dst_row_size = dst->ne[0] * elem_size;
// Align to VLEN
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
bool is_scalar = !is_add_id && (src1->ne[0] == 1);
// Determine which kernel we will use to alloc memory and dispatch
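+ // The "same" vector path additionally requires the src0 row stride to be a multiple of VLEN.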
- bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] &&
+ bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) &&
(src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) &&
(src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) &&
(src1->ne[3] == src0->ne[3] || src1->ne[3] == 1);
}
if (rows_per_buffer < 1) {
- FARF(ERROR, "binary-f32: VTCM too small\n");
+ FARF(ERROR, "binary: VTCM too small\n");
return HTP_STATUS_VTCM_TOO_SMALL;
}
return HTP_STATUS_OK;
}
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
-
dma_queue * q = octx->ctx->dma[0];
if (is_row_bcast) {
- dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1);
+ dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1);
}
struct htp_binary_context bctx;
bctx.octx = octx;
- bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
bctx.block_max = rows_per_buffer;
bctx.src0_row_size_aligned = src0_row_size_aligned;
bctx.src1_row_size_aligned = src1_row_size_aligned;
dma_queue_pop(q);
}
- worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads);
return HTP_STATUS_OK;
}
int op_binary(struct htp_ops_context * octx) {
- if (octx->src0.type == HTP_TYPE_F32) {
- return execute_op_binary_f32(octx);
+
+ // Does not support permutations of src1
+ const struct htp_tensor * src1 = &octx->src1;
+ if (src1->nb[1] < src1->nb[0]) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ const uint32_t src0_type = octx->src0.type;
+ if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) {
+ return execute_op_binary(octx);
}
+
return HTP_STATUS_NO_SUPPORT;
}
+
int op_cpy(struct htp_ops_context * octx) {
cpy_preamble;
+ const uint32_t n_threads = MIN(nr, octx->n_threads);
+
struct htp_copy_context ct;
ct.octx = octx;
const bool transposed = (nb00 > nb01) || (nb0 > nb1);
const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3);
- const uint32_t n_jobs = MIN(nr, octx->n_threads);
- ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+ ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
if (sametype && sameshape) {
ct.copy = cpy_thread_sametype_sameshape;
return HTP_STATUS_NO_SUPPORT;
}
- worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads);
return HTP_STATUS_OK;
}
int op_get_rows(struct htp_ops_context * octx) {
get_rows_preamble;
+ const uint32_t n_threads = MIN(nr, octx->n_threads);
+
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
- const uint32_t n_jobs = MIN(nr, octx->n_threads);
- grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+ grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads;
- worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads);
return HTP_STATUS_OK;
}
// Binary operations (add, mul, sub)
//
-#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \
+#define UNUSED(x) (void)(x)
+
+#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src0_type * restrict vsrc0 = (src0_type *) src0; \
src1_type * restrict vsrc1 = (src1_type *) src1; \
\
- const uint32_t elem_size = sizeof(float); \
- const uint32_t epv = 128 / elem_size; \
+ const uint32_t epv = 128 / (elem_size); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
} \
if (nloe) { \
HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \
- vec_store((void *) &vdst[i], nloe * elem_size, v); \
+ vec_store((void *) &vdst[i], nloe * (elem_size), v); \
} \
} while(0)
#if __HVX_ARCH__ < 79
-#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
-#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
-#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+
+#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
+#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+
#else
-#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
-#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
-#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+
+#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
+#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+
#endif
+#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b)
+#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b)
+#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b)
+
// Generic macro to define alignment permutations for an op
-#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
+#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+ hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
- hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+ hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
- hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
+ hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
- hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
+ hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
- hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+ hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
- hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+ hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src1 % 128 == 0); \
- hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
+ hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
- hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
+ hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
} \
-DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
-DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
-DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float)
+
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16)
+DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16)
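+
+// Each instantiation above expands into the eight dst/src0/src1 alignment variants
+// (_aaa through _uuu) consumed by the dispatchers below.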
// Dispatcher logic
#define HVX_BINARY_DISPATCHER(OP_NAME) \
HVX_BINARY_DISPATCHER(hvx_sub_f32)
HVX_BINARY_DISPATCHER(hvx_mul_f32)
+HVX_BINARY_DISPATCHER(hvx_add_f16)
+HVX_BINARY_DISPATCHER(hvx_sub_f16)
+HVX_BINARY_DISPATCHER(hvx_mul_f16)
+
// Mul-Mul Optimized
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
assert((unsigned long) dst % 128 == 0);
_Pragma("unroll(4)")
for (; i < nvec; i++) {
- HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
+ HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
- vdst[i] = HVX_OP_MUL(v1, vsrc2[i]);
+ vdst[i] = HVX_OP_MUL_F32(v1, vsrc2[i]);
}
if (nloe) {
- HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]);
- HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]);
+ HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]);
+ HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]);
hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2);
}
}
// Scalar Operations
-#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \
+#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
- const uint32_t elem_size = sizeof(float); \
- const uint32_t epv = 128 / elem_size; \
+ const uint32_t epv = 128 / (elem_size); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
if (nloe) { \
HVX_Vector v = vsrc[i]; \
v = scalar_op_macro(v); \
- vec_store((void *) &vdst[i], nloe * elem_size, v); \
+ vec_store((void *) &vdst[i], nloe * (elem_size), v); \
} \
} while(0)
-#define HVX_OP_ADD_SCALAR(v) \
+#define HVX_OP_ADD_SCALAR_F32(v) \
({ \
const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \
- HVX_Vector out = HVX_OP_ADD(v, val_vec); \
+ HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \
Q6_V_vmux_QVV(pred_inf, inf, out); \
})
-#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec)
-#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec)
-
-// Add Scalar Variants
-
-static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
-}
-
-static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
- assert((unsigned long) dst % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR);
-}
-
-static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- const HVX_Vector inf = hvx_vec_splat_f32(INFINITY);
- assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
-}
-
-static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- static const float kInf = INFINITY;
- const HVX_Vector inf = hvx_vec_splat_f32(kInf);
- hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR);
-}
-
-// Sub Scalar Variants
+#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec)
+#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec)
-static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
-}
-
-static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- assert((unsigned long) dst % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR);
-}
-
-static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
-}
-
-static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR);
-}
+#define HVX_OP_ADD_SCALAR_F16(v) \
+ ({ \
+ const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \
+ HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \
+ Q6_V_vmux_QVV(pred_inf, inf, out); \
+ })
-// Mul Scalar Variants
+#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec)
+#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec)
-static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
-}
+// Scalar Variants
-static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- assert((unsigned long) dst % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR);
-}
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \
+static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+ const HVX_Vector val_vec = SPLAT_MACRO(val); \
+ const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src % 128 == 0); \
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+ const HVX_Vector val_vec = SPLAT_MACRO(val); \
+ const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+ assert((uintptr_t) dst % 128 == 0); \
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \
+} \
+static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+ const HVX_Vector val_vec = SPLAT_MACRO(val); \
+ const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+ assert((uintptr_t) src % 128 == 0); \
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
+static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \
+ const HVX_Vector val_vec = SPLAT_MACRO(val); \
+ const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \
+} \
-static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
-}
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float)
-static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
- const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR);
-}
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16)
+DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16)
-static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
- if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
- hvx_add_scalar_f32_aa(dst, src, val, num_elems);
- } else if (hex_is_aligned((void *) dst, 128)) {
- hvx_add_scalar_f32_au(dst, src, val, num_elems);
- } else if (hex_is_aligned((void *) src, 128)) {
- hvx_add_scalar_f32_ua(dst, src, val, num_elems);
- } else {
- hvx_add_scalar_f32_uu(dst, src, val, num_elems);
- }
+// Dispatcher logic
+#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \
+ if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
+ OP_NAME##_aa(dst, src, val, num_elems); \
+ } else if (hex_is_aligned((void *) dst, 128)) { \
+ OP_NAME##_au(dst, src, val, num_elems); \
+ } else if (hex_is_aligned((void *) src, 128)) { \
+ OP_NAME##_ua(dst, src, val, num_elems); \
+ } else { \
+ OP_NAME##_uu(dst, src, val, num_elems); \
+ } \
}
-static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
- if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
- hvx_mul_scalar_f32_aa(dst, src, val, num_elems);
- } else if (hex_is_aligned((void *) dst, 128)) {
- hvx_mul_scalar_f32_au(dst, src, val, num_elems);
- } else if (hex_is_aligned((void *) src, 128)) {
- hvx_mul_scalar_f32_ua(dst, src, val, num_elems);
- } else {
- hvx_mul_scalar_f32_uu(dst, src, val, num_elems);
- }
-}
+HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float)
-static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
- if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) {
- hvx_sub_scalar_f32_aa(dst, src, val, num_elems);
- } else if (hex_is_aligned((void *) dst, 128)) {
- hvx_sub_scalar_f32_au(dst, src, val, num_elems);
- } else if (hex_is_aligned((void *) src, 128)) {
- hvx_sub_scalar_f32_ua(dst, src, val, num_elems);
- } else {
- hvx_sub_scalar_f32_uu(dst, src, val, num_elems);
- }
-}
+HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16)
+HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16)
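+
+// Example (hypothetical call site): hvx_add_scalar_f16(dst, src, (_Float16) 0.5f, n)
+// picks the _aa/_au/_ua/_uu variant matching the 128-byte alignment of dst and src.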
// MIN Scalar variants
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) dst % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR);
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) {
const HVX_Vector val_vec = hvx_vec_splat_f32(val);
- hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR);
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR);
}
static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) {
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+ hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) dst % 128 == 0);
- hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
+ hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
assert((unsigned long) src % 128 == 0);
- hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+ hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) {
const HVX_Vector min_vec = hvx_vec_splat_f32(min);
const HVX_Vector max_vec = hvx_vec_splat_f32(max);
- hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
+ hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR);
}
static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) {
// Square
//
-#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
+#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
- vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+ vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \
} \
if (nloe) { \
- HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
+ HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
- hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+ hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
- hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+ hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
- hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+ hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
- hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+ hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
}
}
-#undef HVX_OP_ADD
-#undef HVX_OP_SUB
-#undef HVX_OP_MUL
+#undef HVX_OP_ADD_F32
+#undef HVX_OP_SUB_F32
+#undef HVX_OP_MUL_F32
+#undef HVX_OP_ADD_F16
+#undef HVX_OP_SUB_F16
+#undef HVX_OP_MUL_F16
#undef hvx_arith_loop_body
-#undef HVX_OP_ADD_SCALAR
-#undef HVX_OP_SUB_SCALAR
-#undef HVX_OP_MUL_SCALAR
+#undef HVX_OP_ADD_SCALAR_F32
+#undef HVX_OP_SUB_SCALAR_F32
+#undef HVX_OP_MUL_SCALAR_F32
+#undef HVX_OP_ADD_SCALAR_F16
+#undef HVX_OP_SUB_SCALAR_F16
+#undef HVX_OP_MUL_SCALAR_F16
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#undef DEFINE_HVX_BINARY_OP_VARIANTS
#undef HVX_BINARY_DISPATCHER
+#undef UNUSED
#endif // HVX_ARITH_H
#endif
+#if __HVX_ARCH__ < 79
+
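+// On targets below v79 these helpers widen each fp16 vector to a qf32 pair via a
+// multiply by +/-1.0, combine the pairs in qf32, and narrow the result back to fp16;
+// v79 and newer use the native half-float forms below.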
+static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+ const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
+ const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16
+ HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
+ HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
+ HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
+ HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
+ return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
+}
+
+static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+ const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16
+ const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16
+ HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one);
+ HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone);
+ HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p));
+ HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p));
+ return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0));
+}
+
+static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+ return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b));
+}
+
+#else
+
+static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+ return Q6_Vhf_vadd_VhfVhf(a, b);
+}
+
+static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+ return Q6_Vhf_vsub_VhfVhf(a, b);
+}
+
+static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b)
+{
+ return Q6_Vhf_vmpy_VhfVhf(a, b);
+}
+
+#endif // __HVX_ARCH__ < 79
+
#endif /* HVX_BASE_H */
#include "hvx-arith.h"
#if __HVX_ARCH__ < 79
-#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
-#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
+// Divide by a scalar in f32: widen the fp16 input to fp32, multiply by the reciprocal held in vec2_sf_const, then narrow the result back to fp16.
+static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+ HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
+ HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32));
+ HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32));
+#else
+ HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0);
+ HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32);
+ HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32);
+#endif
+
+ HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const);
+ HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const);
+
+#if __HVX_ARCH__ < 79
+ HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
+#else
+ HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
+#endif
+ return res;
+}
+
+#define hvx_div_scalar_f16_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
+ \
+ const uint32_t nvec = n / VLEN_FP16; \
+ const uint32_t nloe = n % VLEN_FP16; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+ vdst[i] = res; \
+ } \
+ if (nloe) { \
+ HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
+ } \
+ } while(0)
+
+static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+ const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+ assert((uintptr_t) dst % 128 == 0);
+ assert((uintptr_t) src % 128 == 0);
+ hvx_div_scalar_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+ const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+ assert((uintptr_t) dst % 128 == 0);
+ hvx_div_scalar_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+ const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+ assert((uintptr_t) src % 128 == 0);
+ hvx_div_scalar_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
+ const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+ hvx_div_scalar_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+// Full f16 division via hvx_vec_inverse_f32_guard: widen both fp16 inputs to fp32, multiply the numerator by the guarded inverse of the divisor, then narrow the result back to fp16.
+static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+ // Convert first input to fp32
+ HVX_VectorPair vec1_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0
+ HVX_Vector vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32));
+ HVX_Vector vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32));
+
+ // Convert second input to fp32
+ HVX_VectorPair vec2_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0
+ HVX_Vector vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32));
+ HVX_Vector vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32));
+#else
+ // Convert first input to fp32
+ HVX_VectorPair vec1_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0
+ HVX_Vector vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32);
+ HVX_Vector vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32);
+
+ // Convert second input to fp32
+ HVX_VectorPair vec2_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0
+ HVX_Vector vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32);
+ HVX_Vector vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32);
+#endif
+
+ // Inverse second input in fp32
+ HVX_Vector vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask);
+ HVX_Vector vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask);
+
+ // Multiply first input by inverse of second, in fp32
+ HVX_Vector div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0);
+ HVX_Vector div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1);
+
+ // Convert back to fp16
+#if __HVX_ARCH__ < 79
+ HVX_Vector recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1);
+#else
+ HVX_Vector recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1);
+#endif
+
+ return recip;
+}
+
+#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src0_type * restrict vsrc0 = (src0_type *) src0; \
+ src1_type * restrict vsrc1 = (src1_type *) src1; \
+ \
+ const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
+ const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
+ \
+ const uint32_t nvec = n / VLEN_FP16; \
+ const uint32_t nloe = n % VLEN_FP16; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+ vdst[i] = res; \
+ } \
+ if (nloe) { \
+ HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
+ } \
+ } while(0)
+
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
- HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
+ HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \
vdst[i] = res; \
} \
if (nloe) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
- HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
+ HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
} \
} while(0)
-// 3-letter suffix variants
-static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((uintptr_t) dst % 128 == 0);
- assert((uintptr_t) src0 % 128 == 0);
- assert((uintptr_t) src1 % 128 == 0);
- hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
-}
-
-static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((uintptr_t) dst % 128 == 0);
- assert((uintptr_t) src0 % 128 == 0);
- hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
+// Generic macro to define alignment permutations for an op
+#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
+static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src0 % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src0 % 128 == 0); \
+ OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src0 % 128 == 0); \
+ assert((uintptr_t) src1 % 128 == 0); \
+ OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src0 % 128 == 0); \
+ OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ assert((uintptr_t) src1 % 128 == 0); \
+ OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
+ OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \
+}
+
+// Dispatcher: selects the alignment-specific variant at runtime
+#define HVX_DIV_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
+ if (hex_is_aligned((void *) dst, 128)) { \
+ if (hex_is_aligned((void *) src0, 128)) { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
+ else OP_NAME##_aau(dst, src0, src1, num_elems); \
+ } else { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
+ else OP_NAME##_auu(dst, src0, src1, num_elems); \
+ } \
+ } else { \
+ if (hex_is_aligned((void *) src0, 128)) { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
+ else OP_NAME##_uau(dst, src0, src1, num_elems); \
+ } else { \
+ if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
+ else OP_NAME##_uuu(dst, src0, src1, num_elems); \
+ } \
+ } \
}
-static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((uintptr_t) dst % 128 == 0);
- assert((uintptr_t) src1 % 128 == 0);
- hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
-}
+DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body)
+DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body)
-static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((uintptr_t) dst % 128 == 0);
- hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
-}
-
-static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((uintptr_t) src0 % 128 == 0);
- assert((uintptr_t) src1 % 128 == 0);
- hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((uintptr_t) src0 % 128 == 0);
- hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- assert((uintptr_t) src1 % 128 == 0);
- hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
- hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
-}
-
-static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
- if (hex_is_aligned((void *) dst, 128)) {
- if (hex_is_aligned((void *) src0, 128)) {
- if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
- else hvx_div_f32_aau(dst, src0, src1, num_elems);
- } else {
- if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
- else hvx_div_f32_auu(dst, src0, src1, num_elems);
- }
- } else {
- if (hex_is_aligned((void *) src0, 128)) {
- if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
- else hvx_div_f32_uau(dst, src0, src1, num_elems);
- } else {
- if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
- else hvx_div_f32_uuu(dst, src0, src1, num_elems);
- }
- }
-}
+HVX_DIV_DISPATCHER(hvx_div_f32)
+HVX_DIV_DISPATCHER(hvx_div_f16)
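+// Callers use the plain names, e.g. hvx_div_f16(dst, src0, src1, n); the dispatcher
+// routes to the variant matching the runtime alignment of the three pointers.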
-#undef HVX_OP_MUL
+#undef HVX_OP_MUL_F32
#endif // HVX_DIV_H
} \
} while(0)
-static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- assert((unsigned long) src % 128 == 0);
- hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
-}
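+// Reciprocal with NaN/Inf guard: if all exponent bits of the result are set
+// (NaN or Inf), return 0 instead.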
+static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_hf, HVX_Vector nan_inf_mask) {
+ HVX_Vector out = hvx_vec_inverse_f16(v_hf);
-static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
- assert((unsigned long) dst % 128 == 0);
- hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
-}
+ HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask);
+ const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out);
-static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
- assert((unsigned long) src % 128 == 0);
- hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+ return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out);
}
-static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
- hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
-}
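+// Element-wise fp16 reciprocal loop: nvec full vectors plus a partial store for the
+// nloe leftover elements; 0x7c00 splats the fp16 NaN/Inf exponent mask used by the guard.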
+#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store) \
+ do { \
+ dst_type * restrict vdst = (dst_type *) dst; \
+ src_type * restrict vsrc = (src_type *) src; \
+ \
+ const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
+ \
+ const uint32_t nvec = n / VLEN_FP16; \
+ const uint32_t nloe = n % VLEN_FP16; \
+ \
+ uint32_t i = 0; \
+ \
+ _Pragma("unroll(4)") \
+ for (; i < nvec; i++) { \
+ vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
+ } \
+ if (nloe) { \
+ HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \
+ vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v); \
+ } \
+ } while(0)
-static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) {
- if ((unsigned long) dst % 128 == 0) {
- if ((unsigned long) src % 128 == 0) {
- hvx_inverse_f32_aa(dst, src, num_elems);
- } else {
- hvx_inverse_f32_au(dst, src, num_elems);
- }
- } else {
- if ((unsigned long) src % 128 == 0) {
- hvx_inverse_f32_ua(dst, src, num_elems);
- } else {
- hvx_inverse_f32_uu(dst, src, num_elems);
- }
- }
+// Generic macro to define alignment permutations for an op
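+// The 2-letter suffix encodes pointer alignment for (dst, src): 'a' = 128-byte aligned, 'u' = unaligned.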
+#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \
+static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ assert((uintptr_t) src % 128 == 0); \
+ OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+ assert((uintptr_t) dst % 128 == 0); \
+ OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \
+} \
+static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+ assert((uintptr_t) src % 128 == 0); \
+ OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \
+} \
+static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \
+ OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \
+}
+
+// Dispatcher: selects the alignment-specific variant at runtime
+#define HVX_INV_DISPATCHER(OP_NAME) \
+static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \
+ if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \
+ OP_NAME##_aa(dst, src, num_elems); \
+ } else if (hex_is_aligned((void *) dst, 128)) { \
+ OP_NAME##_au(dst, src, num_elems); \
+ } else if (hex_is_aligned((void *) src, 128)) { \
+ OP_NAME##_ua(dst, src, num_elems); \
+ } else { \
+ OP_NAME##_uu(dst, src, num_elems); \
+ } \
}
+DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body)
+DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body)
+
+HVX_INV_DISPATCHER(hvx_inverse_f32)
+HVX_INV_DISPATCHER(hvx_inverse_f16)
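+// Callers use the plain names, e.g. hvx_inverse_f16(dst, src, n); the dispatcher
+// routes to the variant matching the runtime alignment of dst and src.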
+
#endif // HVX_INVERSE_H
return HTP_STATUS_NO_SUPPORT;
}
- const uint32_t n_threads = octx->n_threads;
+ const uint32_t ne0 = dst->ne[0];
+ const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
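+ // Cap the worker count at the row count so no job is launched without a row to process.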
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
rctx.dst_row_size_aligned = dst_row_size_aligned;
rctx.theta_cache_offset = theta_cache_size_aligned;
- uint32_t ne0 = dst->ne[0];
- uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
rctx.src0_nrows = src0_nrows;
+ rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
- rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads);
}
return err;
int op_set_rows(struct htp_ops_context * octx) {
set_rows_preamble;
+ const uint32_t n_threads = MIN(nr, octx->n_threads);
+
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
srctx.div_ne12 = init_fastdiv_values(ne12);
srctx.div_ne11 = init_fastdiv_values(ne11);
- const uint32_t n_jobs = MIN(nr, octx->n_threads);
- srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+ srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;
switch(octx->dst.type) {
case HTP_TYPE_F32:
- worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
break;
case HTP_TYPE_F16:
- worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads);
break;
default:
return HTP_STATUS_NO_SUPPORT;
return HTP_STATUS_NO_SUPPORT;
}
- const uint32_t n_threads = octx->n_threads;
+ const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const size_t src0_row_size = src0->nb[1];
const size_t src1_row_size = src0_row_size;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
- uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
-
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
- smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
+ smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
+ worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads);
}
return err;
return HTP_STATUS_OK;
}
- const int n_threads = octx->n_threads;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
-
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
- uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
+ const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads;
bool opt_path = false;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
.opt_path = opt_path,
};
- worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads);
return HTP_STATUS_OK;
}
return HTP_STATUS_NO_SUPPORT;
}
- const int n_threads = octx->n_threads;
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+ const uint32_t n_threads = MIN(octx->n_threads, src0_nrows);
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
-
struct htp_unary_context uctx = {
.octx = octx,
- .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
+ .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads,
.src0_nrows = src0_nrows,
.data_src0 = (const uint8_t *)src0->data,
.nc = src0->ne[0],
};
- worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads);
}
return err;