static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
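+// Shared preamble for the f32 binary-op jobs: pulls the tensor views, their dims/strides,
+// and the per-thread row count out of the op context.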
#define htp_binary_preamble \
+ const struct htp_tensor * src0 = &octx->src0; \
+ const struct htp_tensor * src1 = &octx->src1; \
+ const struct htp_tensor * src2 = &octx->src2; \
+ struct htp_tensor * dst = &octx->dst; \
+ \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
- const uint32_t nb3 = dst->nb[3];
-
-static void binary_job_f32_per_thread(const struct htp_tensor * src0,
- const struct htp_tensor * src1,
- struct htp_tensor * dst,
- uint8_t * spad_data,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- enum htp_op op) {
+ const uint32_t nb3 = dst->nb[3]; \
+ \
+ const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
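+// Element-wise binary op (mul/add/sub) over f32 rows; thread ith works on its own
+// contiguous chunk of src0_nrows_per_thread rows.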
+static void binary_job_f32_per_thread(struct htp_ops_context * octx,
+ uint8_t * spad_data,
+ uint32_t nth,
+ uint32_t ith,
+ enum htp_op op) {
htp_binary_preamble;
const size_t src0_row_size = nb01;
uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
- const uint32_t nr0 = ne00 / ne10;
-
const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
- const uint8_t * restrict src1_ptr = NULL;
+
+ const uint32_t ne02_ne01 = ne02 * ne01;
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
- src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
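+ // src0 indices: split the flat row index ir into (i03, i02, i01) using the precomputed fastdiv constants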
+ const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
+ const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
+ const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
+
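+ // src1 indices: wrap with fastmodulo so a smaller src1 is broadcast across src0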
+ const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3);
+ const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2);
+ const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1);
+
+ const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size;
if (ir + 1 < src0_end_row) {
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
}
}
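+ // nr0: how many times the src1 row repeats along dim 0 to cover one src0 row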
+ const uint32_t nr0 = ne00 / ne10;
if (nr0 > 1) {
if ((1 == is_aligned) && (nr0 == ne00)) {
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
- const struct htp_tensor * src1,
- const struct htp_tensor * src2,
- struct htp_tensor * dst,
- uint8_t * spad_data,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- hvx_elemwise_f32_func func_HVX) {
+static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx,
+ uint8_t * spad_data,
+ uint32_t nth,
+ uint32_t ith,
+ hvx_elemwise_f32_func func_HVX) {
htp_binary_preamble;
const size_t src0_row_size = nb01;
const size_t src1_row_size = nb11;
const size_t dst_row_size = nb1;
- const uint32_t ne02_ne01 = ne02 * ne01;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
+ const uint32_t ne02_ne01 = ne02 * ne01;
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
// src0 indices
- const uint32_t i03 = ir / ne02_ne01;
- const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
+ const uint32_t i03 = fastdiv(ir, &octx->src0_div21);
+ const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1);
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
// src1 indices
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
- binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
- octx->src0_nrows_per_thread, octx->op);
+ binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op);
break;
case HTP_OP_ADD_ID:
- binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
- i, octx->src0_nrows_per_thread, hvx_add_f32);
+ binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32);
break;
default:
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
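+ // Precompute the fastdiv/fastmodulo constants once per op so the per-row index math in the
+ // worker jobs needs only a multiply and a shift instead of an integer divide.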
+ octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]);
+ octx->src0_div3 = init_fastdiv_values(src0->ne[3]);
+ octx->src0_div2 = init_fastdiv_values(src0->ne[2]);
+ octx->src0_div1 = init_fastdiv_values(src0->ne[1]);
+
+ octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]);
+ octx->src1_div3 = init_fastdiv_values(src1->ne[3]);
+ octx->src1_div2 = init_fastdiv_values(src1->ne[2]);
+ octx->src1_div1 = init_fastdiv_values(src1->ne[1]);
+
worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
}
return m * ((n + m - 1) / m);
}
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+struct fastdiv_values {
+ uint32_t mp;
+ uint32_t l;
+};
+
+static inline struct fastdiv_values init_fastdiv_values(uint32_t d) {
+ struct fastdiv_values result = { 0, 0 };
+ // compute L = ceil(log2(d));
+ while (result.l < 32 && ((uint32_t) 1 << result.l) < d) {
+ ++(result.l);
+ }
+
+ result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1);
+ return result;
+}
+
+static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) {
+ // Compute high 32 bits of n * mp
+ const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp)
+ // add n, apply bit shift
+ return (hi + n) >> vals->l;
+}
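+
+// Illustrative check (not used by the kernels), assuming the 32-bit operands above:
+// for d = 3, init_fastdiv_values() yields l = 2 and mp = floor(2^32 * (4 - 3) / 3) + 1 = 1431655766,
+// so for n = 7: mulhi(7, mp) = 2 and (2 + 7) >> 2 = 2, which matches 7 / 3.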
+
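+// n % d; vals must have been initialized with the same divisor d.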
+static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) {
+ return n - fastdiv(n, vals) * d;
+}
+
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));