return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
}
-static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
- if (x->ne[0] != y->ne[0]) {
- return false;
- }
- if (x->ne[1] != y->ne[1]) {
- return false;
- }
- if (x->ne[2] != y->ne[2]) {
- return false;
- }
- if (x->ne[3] != y->ne[3]) {
- return false;
- }
-
- return true;
-}
-
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
return opt_experimental;
}
-static bool hex_supported_src0_type(ggml_type t) {
- return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src1_type(ggml_type t) {
- return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src2_type(ggml_type t) {
- return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_src1_type2(ggml_type t) {
- return t == GGML_TYPE_F16;
-}
-
-static bool hex_supported_src1_type3(ggml_type t) {
- return t == GGML_TYPE_I32;
-}
-
-static bool hex_supported_dst_type(ggml_type t) {
- return t == GGML_TYPE_F32;
-}
-
-static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
- // TODO: support broadcast for ne[2 and 3]
- if (x->ne[0] != y->ne[0]) {
- return false;
- }
- if (x->ne[2] != y->ne[2]) {
- return false;
- }
- if (x->ne[3] != y->ne[3]) {
- return false;
- }
- return true;
-}
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
- if (!hex_supported_src0_type(src0->type)) {
+ if (src0->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_src1_type(src1->type)) {
+ if (src1->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dst_type(dst->type)) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dims2(src0, dst)) {
+ if (!ggml_are_same_shape(src0, dst)) {
return false;
}
- if (!ggml_can_repeat(src1, src0)) {
+ if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
return false;
}
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
- if (!hex_supported_src0_type(src0->type)) {
+ if (src0->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_src1_type(src1->type)) {
+ if (src1->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dst_type(dst->type)) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dims2(src0, dst)) {
+ if (!ggml_are_same_shape(src0, dst)) {
return false;
}
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
- if (!hex_supported_src0_type(src0->type)) {
+ if (src0->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dst_type(dst->type)) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dims2(src0, dst)) {
+ if (!ggml_are_same_shape(src0, dst)) {
return false;
}
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
- if (!hex_supported_src0_type(src0->type)) {
+ if (src0->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dst_type(dst->type)) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * dst = op;
- if (!hex_supported_src0_type(src0->type)) {
+ if (src0->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dst_type(dst->type)) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
}
if (src1) {
- if (!hex_supported_src1_type(src1->type)) {
+ if (src1->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dims2(src0, src1)) {
+ if (!ggml_are_same_shape(src0, src1)) {
return false;
}
if (!ggml_is_contiguous(src1)) {
return false; // FIXME: add support for sinks
}
- if (!hex_supported_src0_type(src0->type)) {
+ if (src0->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_dst_type(dst->type)) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
if (src1) {
- if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
+ if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
return false;
}
if (src0->ne[0] != src1->ne[0]) {
const struct ggml_tensor * src2 = op->src[2];
const struct ggml_tensor * dst = op;
- if (!hex_supported_src0_type(src0->type)) {
+ if (src0->type != GGML_TYPE_F32) {
return false; // FIXME: add support for GGML_TYPE_F16 for src0
}
- if (!hex_supported_dst_type(dst->type)) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
- if (!hex_supported_src1_type3(src1->type)) {
+ if (src1->type != GGML_TYPE_I32) {
return false;
}
if (src2) {
- if (!hex_supported_src2_type(src2->type)) {
+ if (src2->type != GGML_TYPE_F32) {
return false;
}
int n_dims = op_params[1];
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
-static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
- const struct htp_tensor * src1,
- struct htp_tensor * dst,
- const int32_t * op_params,
- struct htp_spad * src0_spad,
- struct htp_spad * src1_spad,
- struct htp_spad * dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
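+// Per-op context shared by all activation worker threads. Values that each
+// *_per_thread function used to recompute are precomputed once in
+// execute_op_activations_f32; each thread then derives its row range and
+// scratchpad slice from its worker index (ith).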
+struct htp_act_context {
+ struct htp_ops_context * octx;
+
+ // Precomputed values
+ const uint8_t * data_src0;
+ const uint8_t * data_src1;
+ uint8_t * data_dst;
+
+ size_t src0_row_size;
+ size_t src1_row_size;
+ size_t dst_row_size;
+
+ size_t src0_row_size_aligned;
+ size_t src1_row_size_aligned;
+ size_t dst_row_size_aligned;
+
+ size_t src0_spad_half_size;
+ size_t src1_spad_half_size;
+ size_t dst_spad_half_size;
+
+ uint32_t block;
+ uint32_t src0_nrows;
+ uint32_t src0_nrows_per_thread;
+ int nc;
+};
+
+static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_act_context * actx = (struct htp_act_context *) data;
+ const struct htp_tensor * src0 = &actx->octx->src0;
+ const struct htp_tensor * src1 = &actx->octx->src1;
+ const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
- size_t src0_row_size = nb01;
- size_t src1_row_size = nb11;
- size_t dst_row_size = nb1;
-
-
-
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ size_t src0_row_size = actx->src0_row_size;
+ size_t src1_row_size = actx->src1_row_size;
+ size_t dst_row_size = actx->dst_row_size;
+ const uint32_t src0_nrows = actx->src0_nrows;
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
-
- const bool src1_valid = src1->ne[0];
- const int nc = (src1_valid) ? ne00 : ne00 / 2;
- if (!src1_valid) {
- const int32_t swapped = op_params[1];
- data_src1 = data_src0;
- src1_row_size = src0_row_size;
+ const uint8_t * restrict data_src0 = actx->data_src0;
+ const uint8_t * restrict data_src1 = actx->data_src1;
+ uint8_t * restrict data_dst = actx->data_dst;
- const size_t nc_in_bytes = nc * SIZEOF_FP32;
- data_src0 += swapped ? nc_in_bytes : 0;
- data_src1 += swapped ? 0 : nc_in_bytes;
- }
+ const int nc = actx->nc;
- const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
- const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
- const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+ const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
- uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
- uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
- uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+ uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+ uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+ uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
- size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
+ size_t src1_spad_half_size = actx->src1_spad_half_size;
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
- src0_spad->size_per_thread, src0_row_size_aligned);
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
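+ // Each thread uses its own DMA queue (indexed by ith) for the ping-pong transfers below.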
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
- const struct htp_tensor * src1,
- struct htp_tensor * dst,
- const int32_t * op_params,
- struct htp_spad * src0_spad,
- struct htp_spad * src1_spad,
- struct htp_spad * dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_act_context * actx = (struct htp_act_context *) data;
+ const struct htp_tensor * src0 = &actx->octx->src0;
+ const struct htp_tensor * src1 = &actx->octx->src1;
+ const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
- size_t src0_row_size = nb01;
- size_t src1_row_size = nb11;
- size_t dst_row_size = nb1;
+ size_t src0_row_size = actx->src0_row_size;
+ size_t src1_row_size = actx->src1_row_size;
+ size_t dst_row_size = actx->dst_row_size;
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ const uint32_t src0_nrows = actx->src0_nrows;
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
return;
}
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
-
- const bool src1_valid = src1->ne[0];
- const int nc = (src1_valid) ? ne00 : ne00 / 2;
- if (!src1_valid) {
- const int32_t swapped = op_params[1];
- data_src1 = data_src0;
- src1_row_size = src0_row_size;
+ const uint8_t * restrict data_src0 = actx->data_src0;
+ const uint8_t * restrict data_src1 = actx->data_src1;
+ uint8_t * restrict data_dst = actx->data_dst;
- const size_t nc_in_bytes = nc * SIZEOF_FP32;
- data_src0 += swapped ? nc_in_bytes : 0;
- data_src1 += swapped ? 0 : nc_in_bytes;
- }
+ const int nc = actx->nc;
- const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
- const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
- const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+ const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
- uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
- uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
- uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+ uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+ uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+ uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
- size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
+ size_t src1_spad_half_size = actx->src1_spad_half_size;
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
"%zu\n",
- src0_spad->size_per_thread, src0_row_size_aligned);
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
- const float alpha = ((const float *) (op_params))[2];
- const float limit = ((const float *) (op_params))[3];
+ const float alpha = ((const float *) (actx->octx->op_params))[2];
+ const float limit = ((const float *) (actx->octx->op_params))[3];
+
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
}
-static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
- struct htp_tensor * dst,
- const int32_t * op_params,
- struct htp_spad * src0_spad,
- struct htp_spad * dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_act_context * actx = (struct htp_act_context *) data;
+ const struct htp_tensor * src0 = &actx->octx->src0;
+ const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble2;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
- const size_t src0_row_size = nb01;
- const size_t dst_row_size = nb1;
- const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
- const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ const size_t src0_row_size = actx->src0_row_size;
+ const size_t dst_row_size = actx->dst_row_size;
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
- const uint32_t src0_nrows = ne01 * ne02 * ne03;
+ const uint32_t src0_nrows = actx->src0_nrows;
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
return;
}
- const uint8_t * data_src0 = (const uint8_t *) src0->data;
- uint8_t * data_dst = (uint8_t *) dst->data;
+ const uint8_t * data_src0 = actx->data_src0;
+ uint8_t * data_dst = actx->data_dst;
- uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
- uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+ // actx->nc is set to dst->ne[0] by the dispatcher, so it doubles as the row width here.
+ const int ne0_val = actx->nc; // == dst->ne[0]
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+ uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
+
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
// In gelu = x*sigmoid(x*1.702)
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
- src0_spad->size_per_thread, src0_row_size_aligned);
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// gelu = x * sigmoid(1.702 * x) // current implementation
- hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
- hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
- hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
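+ // x * sigmoid(1.702 * x) is the standard sigmoid approximation of exact GELU = x * Phi(x).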
+ hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
- unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-
-static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
- struct htp_tensor * dst,
- const int32_t * op_params,
- struct htp_spad * src0_spad,
- struct htp_spad * dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_act_context * actx = (struct htp_act_context *) data;
+ const struct htp_tensor * src0 = &actx->octx->src0;
+ const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble2;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
- const size_t src0_row_size = nb01;
- const size_t dst_row_size = nb1;
- const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
- const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ const size_t src0_row_size = actx->src0_row_size;
+ const size_t dst_row_size = actx->dst_row_size;
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
- const uint32_t src0_nrows = ne01 * ne02 * ne03;
+ const uint32_t src0_nrows = actx->src0_nrows;
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
return;
}
- const uint8_t * data_src0 = (const uint8_t *) src0->data;
- uint8_t * data_dst = (uint8_t *) dst->data;
+ const uint8_t * data_src0 = actx->data_src0;
+ uint8_t * data_dst = actx->data_dst;
- uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
- uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+ const int ne0_val = actx->nc; // == dst->ne[0]
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+ uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
+
+ const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
- src0_spad->size_per_thread, src0_row_size_aligned);
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// silu = x * sigmoid(x)
- hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
- hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
+ hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
+ hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
static const float GELU_COEF_A = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
- const struct htp_tensor * src1,
- struct htp_tensor * dst,
- const int32_t * op_params,
- struct htp_spad * src0_spad,
- struct htp_spad * src1_spad,
- struct htp_spad * dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_act_context * actx = (struct htp_act_context *) data;
+ const struct htp_tensor * src0 = &actx->octx->src0;
+ const struct htp_tensor * src1 = &actx->octx->src1;
+ const struct htp_tensor * dst = &actx->octx->dst;
htp_act_preamble3;
- size_t src0_row_size = nb01;
- size_t src1_row_size = nb11;
- size_t dst_row_size = nb1;
+ size_t src0_row_size = actx->src0_row_size;
+ size_t src1_row_size = actx->src1_row_size;
+ size_t dst_row_size = actx->dst_row_size;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ const uint32_t src0_nrows = actx->src0_nrows;
+ const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
return;
}
- const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
- const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
+ const uint8_t * restrict data_src0 = actx->data_src0;
+ const uint8_t * restrict data_src1 = actx->data_src1;
+ uint8_t * restrict data_dst = actx->data_dst;
- const bool src1_valid = src1->ne[0];
- const int nc = (src1_valid) ? ne00 : ne00 / 2;
- if (!src1_valid) {
- const int32_t swapped = op_params[1];
- data_src1 = data_src0;
- src1_row_size = src0_row_size;
-
- const size_t nc_in_bytes = nc * SIZEOF_FP32;
- data_src0 += swapped ? nc_in_bytes : 0;
- data_src1 += swapped ? 0 : nc_in_bytes;
- }
+ const int nc = actx->nc;
- const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
- const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
- const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
+ const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
+ const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
- uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
- uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
- uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
+ uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
+ uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
+ uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
- // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
- size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
- size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
- size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
+ size_t src0_spad_half_size = actx->src0_spad_half_size;
+ size_t src1_spad_half_size = actx->src1_spad_half_size;
+ size_t dst_spad_half_size = actx->dst_spad_half_size;
- const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
+ const int BLOCK = actx->block;
if (BLOCK == 0) {
FARF(ERROR,
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
- src0_spad->size_per_thread, src0_row_size_aligned);
+ actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
return;
}
+ dma_queue * dma_queue = actx->octx->ctx->dma[ith];
+
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
- unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
- glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
- &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
- glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
- &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
-static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
- glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
- &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
-}
-
static int execute_op_activations_f32(struct htp_ops_context * octx) {
- int err = HTP_STATUS_OK;
-
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
switch (octx->op) {
case HTP_OP_UNARY_SILU:
- act_op_func = unary_silu_f32;
+ act_op_func = unary_silu_f32_per_thread;
op_type = "silu-f32";
break;
case HTP_OP_GLU_SWIGLU:
- act_op_func = glu_swiglu_f32;
+ act_op_func = glu_swiglu_f32_per_thread;
op_type = "swiglu-f32";
break;
case HTP_OP_GLU_SWIGLU_OAI:
- act_op_func = glu_swiglu_oai_f32;
+ act_op_func = glu_swiglu_oai_f32_per_thread;
op_type = "swiglu-oai-f32";
break;
case HTP_OP_UNARY_GELU:
- act_op_func = unary_gelu_f32;
+ act_op_func = unary_gelu_f32_per_thread;
op_type = "gelu-f32";
break;
case HTP_OP_GLU_GEGLU:
- act_op_func = glu_geglu_f32;
+ act_op_func = glu_geglu_f32_per_thread;
op_type = "geglu-f32";
break;
default:
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
}
- if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
}
- return err;
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+
+ // Prepare context
+ struct htp_act_context actx;
+ actx.octx = octx;
+
+ actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+
+ actx.src0_row_size = src0_row_size;
+ actx.src1_row_size = src1_row_size;
+ actx.dst_row_size = dst_row_size;
+
+ actx.src0_row_size_aligned = src0_row_size_aligned;
+ actx.src1_row_size_aligned = src1_row_size_aligned;
+ actx.dst_row_size_aligned = dst_row_size_aligned;
+
+ actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
+ actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
+ actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2;
+
+ actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
+ actx.src0_nrows = src0_nrows;
+
+ actx.nc = dst->ne[0];
+
+ // Pointers and GLU logic
+ const uint8_t * data_src0 = (const uint8_t *) src0->data;
+ const uint8_t * data_src1 = (const uint8_t *) src1->data;
+
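+ // Fused GLU with a single input tensor: the gate and value halves share each
+ // src0 row, and op_params[1] ("swapped") selects which half feeds the activation.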
+ if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
+ const int32_t swapped = octx->op_params[1];
+ data_src1 = data_src0;
+ actx.src1_row_size = actx.src0_row_size;
+
+ const size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
+ if (swapped) {
+ data_src0 += nc_in_bytes;
+ } else {
+ data_src1 += nc_in_bytes;
+ }
+ }
+
+ actx.data_src0 = data_src0;
+ actx.data_src1 = data_src1;
+ actx.data_dst = (uint8_t *) dst->data;
+
+ worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
+ return HTP_STATUS_OK;
}
int op_activations(struct htp_ops_context * octx) {
#include "htp-ops.h"
#include "hvx-utils.h"
+struct get_rows_context {
+ struct htp_ops_context * octx;
+ uint32_t src1_nrows_per_thread;
+ struct fastdiv_values get_rows_div_ne10;
+ struct fastdiv_values get_rows_div_ne10_ne11;
+};
+
#define get_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
\
const uint32_t nr = ne10 * ne11 * ne12;
-static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void * data) {
+ struct get_rows_context * grctx = (struct get_rows_context *) data;
+ struct htp_ops_context * octx = grctx->octx;
get_rows_preamble;
// parallelize by src1 elements (which correspond to dst rows)
- const uint32_t dr = octx->src1_nrows_per_thread;
+ const uint32_t dr = grctx->src1_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
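+ // Decompose the flat row index i as i = (i12 * ne11 + i11) * ne10 + i10;
+ // e.g. ne10 = 5, ne11 = 3, i = 22 -> i12 = 1, rem = 7, i11 = 1, i10 = 2.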
for (uint32_t i = ir0; i < ir1; ++i) {
- const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
+ const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
const uint32_t rem = i - i12 * ne11 * ne10;
- const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
+ const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
const uint32_t i10 = rem - i11 * ne10;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
-
- return HTP_STATUS_OK;
-}
-
-static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
- get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_get_rows(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
- octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
- octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
+ struct get_rows_context grctx;
+ grctx.octx = octx;
+ grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
+ grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
- octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+ grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
return HTP_STATUS_OK;
}
dmlink(q->tail, desc);
q->tail = desc;
- // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
+ // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
return true;
}
dptr = q->dptr[q->pop_idx];
- // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
+ // FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
return dptr;
}
+static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
+ dma_ptr dptr = { NULL };
+
+ if (q->push_idx == q->pop_idx) {
+ return dptr;
+ }
+
+ dptr = q->dptr[q->pop_idx];
+
+ // FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
+ q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
+ return dptr;
+}
+
+static inline bool dma_queue_empty(dma_queue * q) {
+ return q->push_idx == q->pop_idx;
+}
+
+static inline uint32_t dma_queue_depth(dma_queue * q) {
+ return (q->push_idx - q->pop_idx) & q->idx_mask;
+}
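+
+// With a power-of-two capacity (idx_mask == capacity - 1, as the masking above
+// implies), the depth stays correct across index wrap-around: e.g. push_idx = 1,
+// pop_idx = 254, mask = 255 -> (1 - 254) & 255 = 3 entries in flight.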
+
+static inline uint32_t dma_queue_capacity(dma_queue * q) {
+ return q->capacity;
+}
+
#ifdef __cplusplus
} // extern "C"
#endif
uint32_t src0_nrows_per_thread;
uint32_t src1_nrows_per_thread;
- struct fastdiv_values src0_div1; // fastdiv values for ne1
- struct fastdiv_values src0_div2; // fastdiv values for ne2
- struct fastdiv_values src0_div3; // fastdiv values for ne3
- struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
-
- struct fastdiv_values src1_div1; // fastdiv values for ne1
- struct fastdiv_values src1_div2; // fastdiv values for ne2
- struct fastdiv_values src1_div3; // fastdiv values for ne3
- struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
-
- struct fastdiv_values src3_div1; // fastdiv values for ne1
- struct fastdiv_values src3_div2; // fastdiv values for ne2
- struct fastdiv_values src3_div3; // fastdiv values for ne3
- struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
-
- struct fastdiv_values broadcast_rk2;
- struct fastdiv_values broadcast_rk3;
- struct fastdiv_values broadcast_rv2;
- struct fastdiv_values broadcast_rv3;
-
- struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
- struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
-
- struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
- struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
-
uint32_t flags;
};
struct fastdiv_values mm_div_r3;
};
-// vdelta control to replicate first 4x fp32 values across lanes
-static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
- 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
- 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
- 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
- 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
- 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
- 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
- 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
-};
-
-// vdelta control to replicate and interleave first 8x fp32 values across lanes
-static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
- 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
- 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
- 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
- 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
- 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
- 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
- 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
-};
-
-// vdelta control to replicate first fp32 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
- 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
- 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
- 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
- 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
- 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
- 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
- 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
-};
-
-// vdelta control to replicate first fp16 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
- 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
- 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
- 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
- 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
- 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
- 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
- 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-};
-
-// vdelta control to replicate first fp16 value across all elements
-static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
- 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
- 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
- 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
- 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
- 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
-};
-
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert to QF32
- HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
- HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
- HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
- HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
+ HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
+ HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
+ HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
+ HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
// Combine and convert to fp16
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
- // Replicate first fp16 scale across all lanes
- HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
- vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
- vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
-
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Compute max and scale
- HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
- HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
-
- // Replicate first fp16 scale across all lanes
- HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
- vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
- vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
+ HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
+ HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
// Compute max and scale
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
- vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
-
- // Replicate first fp16 scale across all lanes
- HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
- vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
+ vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
#include "hex-dma.h"
#include "hvx-utils.h"
+#include "hex-fastdiv.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#define HTP_ROPE_TYPE_NORMAL 0
#define HTP_ROPE_TYPE_NEOX 2
+#define HTP_ROPE_SPAD_NROWS 16
+#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
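+// The per-thread scratchpad holds HTP_ROPE_SPAD_NROWS rows, processed as two
+// HTP_ROPE_SPAD_BLOCK halves so one half is computed while the other is prefetched.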
+
#define htp_rope_preamble \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
-struct rope_th_ctx {
+struct htp_rope_context {
int32_t n_dims;
int32_t mode;
int32_t n_ctx_orig;
float theta_scale;
float corr_dims[2];
+ uint32_t src0_nrows_per_thread;
+ size_t spad_stride;
+
struct htp_ops_context * octx;
+
+ size_t src0_row_size;
+ size_t dst_row_size;
+ size_t src0_row_size_aligned;
+ size_t dst_row_size_aligned;
+ size_t theta_cache_offset;
+ uint32_t src0_nrows;
+
+ uint64_t t_start;
};
static float rope_yarn_ramp(const float low, const float high, const int i0) {
dims[1] = MIN(n_dims - 1, end);
}
-static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
- memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
-
- const int32_t * op_params = &octx->op_params[0];
-
- rope_ctx->n_dims = ((const int32_t *) op_params)[1];
- rope_ctx->mode = ((const int32_t *) op_params)[2];
- rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
-
- memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
- memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
- memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
- memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
- memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
- memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
- memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
-
- rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
-
- rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
- rope_ctx->beta_slow, rope_ctx->corr_dims);
-
- rope_ctx->octx = octx;
- FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
- rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
-}
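+// The _aa suffix follows the hvx-utils naming convention: both dst and src0 are
+// assumed VLEN-aligned (they point at VTCM scratchpad rows), so plain
+// HVX_Vector loads/stores are safe.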
+static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
+ const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
+ const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
+ HVX_Vector * restrict vdst = (HVX_Vector *) dst;
-static void hvx_calc_rope_neox_f32(const float * restrict src0,
- float * restrict dst,
- const int num_elems,
- const float * restrict theta_cache) {
- // for (int i = 0; i < num_elems; i += 2) {
- //const float cos_theta = theta_cache[i + 0];
- //const float sin_theta = theta_cache[i + 1];
+ uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of 2
- //const float x0 = src[0];
- //const float x1 = src[num_elems/2];
+ uint32_t he = ne / 2; // half_dims offset in elements
+ uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
- //dst[0] = x0*cos_theta - x1*sin_theta;
- //dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
+ #pragma unroll(2)
+ for (uint32_t i = 0; i < nvec; i += 2) {
+ HVX_Vector v0 = vsrc[i/2+0];
+ HVX_Vector v1 = vsrc[i/2+hv];
- //src += 1;
- //dst += 1;
- // }
-
- const uint8_t * restrict src0_curr = (const uint8_t *) src0;
- const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
- uint8_t * restrict dst_curr = (uint8_t *) dst;
-
- int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
- int half_size = (sizeof(float) * (num_elems / 2));
-
- for (int i = 0; i < step_of_1; i++) {
- HVX_Vector v0 = *(HVX_Vector *) src0_curr;
- HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
-
- HVX_Vector v2 = *(HVX_Vector *) theta_curr;
- HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+ HVX_Vector v2 = vtheta[i+0];
+ HVX_Vector v3 = vtheta[i+1];
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
- *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
- *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
+ vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4);
+ vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
+ }
- src0_curr += VLEN;
- theta_curr += 2 * VLEN;
- dst_curr += VLEN;
+ for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
+ const float cos_theta = theta_cache[i+0];
+ const float sin_theta = theta_cache[i+1];
+ float x0 = src0[i/2];
+ float x1 = src0[i/2 + he];
+ dst[i/2] = x0 * cos_theta - x1 * sin_theta;
+ dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
}
}
-static void hvx_calc_rope_f32(const float * restrict src0,
- float * restrict dst,
- const int num_elems,
- const float * restrict theta_cache) {
- // for (int i = 0; i < num_elems; i += 2) {
- //const float cos_theta = theta_cache[i + 0];
- //const float sin_theta = theta_cache[i + 1];
-
- //const float x0 = src[0];
- //const float x1 = src[1];
+static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
+ const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
+ const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
+ HVX_Vector * restrict vdst = (HVX_Vector *) dst;
- //dst[0] = x0*cos_theta - x1*sin_theta;
- //dst[1] = x0*sin_theta + x1*cos_theta;
+ uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of 2
- //src += 2;
- //dst += 2;
- // }
+ #pragma unroll(2)
+ for (uint32_t i = 0; i < nvec; i += 2) {
+ HVX_Vector v0 = vsrc[i+0];
+ HVX_Vector v1 = vsrc[i+1];
- const uint8_t * restrict src0_curr = (const uint8_t *) src0;
- const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
- uint8_t * restrict dst_curr = (uint8_t *) dst;
-
- int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
-
- for (int i = 0; i < step_of_1; i++) {
- HVX_Vector v0 = *(HVX_Vector *) src0_curr;
- HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
-
- HVX_Vector v2 = *(HVX_Vector *) theta_curr;
- HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
+ HVX_Vector v2 = vtheta[i+0];
+ HVX_Vector v3 = vtheta[i+1];
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
- *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
- *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
+ vdst[i+0] = Q6_V_lo_W(vstore);
+ vdst[i+1] = Q6_V_hi_W(vstore);
+ }
+
+ for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
+ const float cos_theta = theta_cache[i+0];
+ const float sin_theta = theta_cache[i+1];
+ float x0 = src0[i+0];
+ float x1 = src0[i+1];
+ dst[i+0] = x0 * cos_theta - x1 * sin_theta;
+ dst[i+1] = x0 * sin_theta + x1 * cos_theta;
+ }
+}
+
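+// NORMAL mode rotates adjacent element pairs (x[2i], x[2i+1]); NEOX mode rotates
+// split halves (x[i], x[i + n_dims/2]). Both share the same interleaved
+// cos/sin theta cache.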
+static inline void rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
+ uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
+ #pragma unroll(4)
+ for (uint32_t i = 0; i < nr; i++) {
+ float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
+ float * s = (float *) (src + i * rctx->src0_row_size_aligned);
+
+ hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
+
+ // fill the remaining channels with data from the src tensor
+ if (rctx->n_dims < ne0) {
+ hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
+ }
+ }
+}
+
+static inline void rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
+ uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
+ #pragma unroll(4)
+ for (uint32_t i = 0; i < nr; i++) {
+ float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
+ float * s = (float *) (src + i * rctx->src0_row_size_aligned);
- src0_curr += 2 * VLEN;
- theta_curr += 2 * VLEN;
- dst_curr += 2 * VLEN;
+ hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
+
+ // fill the remaining channels with data from the src tensor
+ if (rctx->n_dims < ne0) {
+ hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
+ }
}
}
-static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
- const uint32_t ir0,
- const uint32_t ir1,
- int nth,
- int ith,
- const int opt_path) {
- struct htp_ops_context * octx = rope_ctx->octx;
+static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_rope_context * rctx = (struct htp_rope_context *) data;
+ struct htp_ops_context * octx = rctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
- const int32_t mode = rope_ctx->mode;
- const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
-
htp_rope_preamble;
- const int32_t * pos = (const int32_t *) src1->data;
+ const uint32_t src0_nrows = rctx->src0_nrows;
+ const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
- float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
- const float * freq_factors = NULL;
- if (src2 != NULL) {
- freq_factors = (const float *) src2->data;
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
}
- const uint32_t i1_end = MIN(ir1, ne1);
- const int32_t half_dims = rope_ctx->n_dims / 2;
- const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
- for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
- for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
- const int32_t p = pos[i2];
+ uint64_t tt = HAP_perf_get_qtimer_count();
- rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
- rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
+ const int32_t mode = rctx->mode;
+ const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
- for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
- const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
- float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
+ // VTCM setup
+ uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ float * theta_cache = (float *) (src0_spad_base);
+ src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
+ uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
- const float * src_loc = src;
- float * dst_data_loc = dst_data;
+ dma_queue * dma_queue = octx->ctx->dma[ith];
+ const int32_t * pos = (const int32_t *) src1->data;
+ const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
- if (1 == opt_path) {
- if (is_neox) {
- hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
- } else {
- hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
- }
+ uint32_t ir = 0;
+ uint32_t prev_i2 = (uint32_t) -1;
- src_loc += rope_ctx->n_dims;
- dst_data_loc += rope_ctx->n_dims;
- } else {
- for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
- const float cos_theta = wp0[i0 + 0];
- const float sin_theta = wp0[i0 + 1];
+ for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
+ for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
+ for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
+ if (ir < src0_start_row) { ir++; i1++; continue; }
+ if (ir >= src0_end_row) goto done;
- if (is_neox) {
- const float x0 = src_loc[0];
- const float x1 = src_loc[half_dims];
+ // Rows in this block
+ const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
- dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
- dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
+ // DMA queue depth before this block's prefetch (leftover transactions to drain below)
+ uint32_t dma_depth = dma_queue_depth(dma_queue);
- src_loc += 1;
- dst_data_loc += 1;
- } else {
- const float x0 = src_loc[0];
- const float x1 = src_loc[1];
+ // FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
+ // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
- dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
- dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
+ // Prefetch loop
+ for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
+ pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
- src_loc += 2;
- dst_data_loc += 2;
- }
- }
+ uint32_t pi1 = i1 + pr;
+ uint32_t pir = ir + pr;
+
+ // Zero-length vtcm->ddr push that carries the dst spad pointer through the
+ // queue, keeping pops interleaved as dst,src,dst,...
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
- src_loc += (is_neox ? half_dims : 0);
- dst_data_loc += (is_neox ? half_dims : 0);
+ const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
+ uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
+ rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
+
+ // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
}
- // TODO: use simd to speed up the remaining elements copy
- memcpy(dst_data_loc, src_loc, remain_bytes);
- }
- }
- }
-}
+ // Update theta cache
+ if (i2 != prev_i2) {
+ prev_i2 = i2;
-static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
- struct htp_ops_context * octx = rope_ctx->octx;
+ const int32_t p = pos[i2];
+ rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
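+                    // theta_cache now holds interleaved cos/sin factors for position p,
+                    // shared by every attention-head row processed under this (i3, i2)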
- const struct htp_tensor * src0 = &octx->src0;
- const struct htp_tensor * src1 = &octx->src1;
- struct htp_tensor * dst = &octx->dst;
+ // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
+ // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
+ }
- htp_rope_preamble;
+                // Skip DMA transactions left over from the previous block (if any).
+                // No need to wait for these since the DMA queue is set up for in-order processing.
+                for (uint32_t d = 0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
- const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+ // Compute loop
+ for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
+ // Number of rows to compute
+ cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
- const uint32_t src0_start_row = src0_nrows_per_thread * ith;
- const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
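+                    // Pops complete in push order: the first returns this block's dst half
+                    // (from the dummy push), the second the prefetched src rows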
+ uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
+ uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
- // no work for this thread
- if (src0_start_row >= src0_end_row) {
- return;
- }
+ // FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
+ // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
- uint64_t t1, t2;
- t1 = HAP_perf_get_qtimer_count();
+ if (is_neox) {
+ rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
+ } else {
+ rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
+ }
- int is_aligned = 1;
- int opt_path = 0;
- if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
- (0 == hex_is_aligned((void *) dst->data, VLEN))) {
- FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
- is_aligned = 0;
- }
- if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
- opt_path = 1;
- }
+ uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
+ dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
- rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
+ // Prefetch more rows (if any)
+ if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
+ uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
+ uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
+ uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
- t2 = HAP_perf_get_qtimer_count();
+ const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
+ rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
- FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
- (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
-}
+ // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
+ }
+ }
+ }
+ }
+ }
-static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
- struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
+done:
+ dma_queue_flush(dma_queue);
+ tt = HAP_perf_get_qtimer_count() - tt;
- rope_job_f32_per_thread(rope_ctx, n, i);
+    FARF(HIGH, "rope-f32: %u/%u: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
}
static int execute_op_rope_f32(struct htp_ops_context * octx) {
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;
- worker_callback_t op_func;
- const char * op_type = NULL;
-
- struct rope_th_ctx rope_ctx;
+ const char * op_type = "rope-f32";
switch (octx->op) {
case HTP_OP_ROPE:
- op_func = rope_job_dispatcher_f32;
- op_type = "rope-f32";
-
- init_rope_ctx(&rope_ctx, octx);
break;
default:
const uint32_t n_threads = octx->n_threads;
const size_t src0_row_size = src0->nb[1];
- const size_t src1_row_size = src0_row_size;
const size_t dst_row_size = dst->nb[1];
- // VTCM scratchpads for all tensors
- // N rows per thread, padded to HVX vector size
- octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
- octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
- octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
-
- size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
-
- if (src2->ne[0]) {
- FARF(HIGH,
- "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
- "dst-spad-size %u\n",
- op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
- src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
- dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
- } else {
- FARF(HIGH,
- "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
- op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
- src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
- octx->dst_spad.size);
- }
+ // Aligned row sizes for VTCM
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
+ const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
+
+ // Calculate spad sizes per thread
+ size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
+ size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
+ size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
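+    // Per-thread VTCM layout: [theta cache | HTP_ROPE_SPAD_NROWS src0 rows] plus a separate dst area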
- // Make sure the reserved vtcm size is sufficient
- if (octx->ctx->vtcm_size < spad_size) {
- FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
- spad_size);
+ // Check if we fit in VTCM
+ size_t total_vtcm_needed = spad_per_thread * n_threads;
+ if (octx->ctx->vtcm_size < total_vtcm_needed) {
+ FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
return HTP_STATUS_VTCM_TOO_SMALL;
}
+ // Assign sizes
+ octx->src0_spad.size_per_thread = src0_spad_per_thread;
+ octx->dst_spad.size_per_thread = dst_spad_per_thread;
+ octx->src0_spad.size = n_threads * src0_spad_per_thread;
+ octx->dst_spad.size = n_threads * dst_spad_per_thread;
+ octx->src1_spad.size = 0;
+
+ // Assign pointers
octx->src0_spad.data = octx->ctx->vtcm_base;
- octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
- octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+ octx->src1_spad.data = NULL;
+ octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+
+ // Fill context
+ struct htp_rope_context rctx;
+ memset(&rctx, 0, sizeof(struct htp_rope_context));
+
+ rctx.t_start = HAP_perf_get_qtimer_count();
+
+ rctx.octx = octx;
+
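+    // op_params mirrors the host-side ggml rope op_params layout:
+    // [1] n_dims, [2] mode, [4] n_ctx_orig, [5..10] freq_base, freq_scale, ext_factor,
+    // attn_factor, beta_fast, beta_slow, [11..14] sections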
+    const int32_t * op_params = &octx->op_params[0];
+    rctx.n_dims = op_params[1];
+    rctx.mode = op_params[2];
+    rctx.n_ctx_orig = op_params[4];
+    memcpy(&rctx.freq_base, op_params + 5, sizeof(float));
+    memcpy(&rctx.freq_scale, op_params + 6, sizeof(float));
+    memcpy(&rctx.ext_factor, op_params + 7, sizeof(float));
+    memcpy(&rctx.attn_factor, op_params + 8, sizeof(float));
+    memcpy(&rctx.beta_fast, op_params + 9, sizeof(float));
+    memcpy(&rctx.beta_slow, op_params + 10, sizeof(float));
+    memcpy(&rctx.sections, op_params + 11, sizeof(int) * 4);
+
+ rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
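+    // Geometric frequency ladder: dim pair i rotates at theta_i = p * freq_base^(-2i/n_dims)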
+
+ rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
+
+ rctx.src0_row_size = src0_row_size;
+ rctx.dst_row_size = dst_row_size;
+ rctx.src0_row_size_aligned = src0_row_size_aligned;
+ rctx.dst_row_size_aligned = dst_row_size_aligned;
+ rctx.theta_cache_offset = theta_cache_size_aligned;
+
+ uint32_t ne0 = dst->ne[0];
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
+ rctx.src0_nrows = src0_nrows;
+
+ FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
+ rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
- uint32_t n_jobs = MIN(n_threads, src0_nrows);
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
+ uint32_t n_jobs = MIN(n_threads, src0_nrows);
+ rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
}
return err;
\
const uint32_t nr = ne01;
-static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+struct htp_set_rows_context {
+ struct htp_ops_context * octx;
+ struct fastdiv_values div_ne12;
+ struct fastdiv_values div_ne11;
+ uint32_t src0_nrows_per_thread;
+};
+
+static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *) data;
+ struct htp_ops_context * octx = srctx->octx;
+
set_rows_preamble;
// parallelize by rows of src0
- const uint32_t dr = octx->src0_nrows_per_thread;
+ const uint32_t dr = srctx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
- const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
- const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+ const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
+ const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
}
}
}
-
- return HTP_STATUS_OK;
}
-static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void * data) {
+    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *) data;
+ struct htp_ops_context * octx = srctx->octx;
+
set_rows_preamble;
// parallelize by rows of src0
- const uint32_t dr = octx->src0_nrows_per_thread;
+ const uint32_t dr = srctx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
- const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
- const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+ const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
+ const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
}
}
}
-
- return HTP_STATUS_OK;
-}
-
-static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
- set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
-}
-
-static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
- set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_set_rows(struct htp_ops_context * octx) {
return HTP_STATUS_OK;
}
- octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
- octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
+ struct htp_set_rows_context srctx;
+ srctx.octx = octx;
+ srctx.div_ne12 = init_fastdiv_values(ne12);
+ srctx.div_ne11 = init_fastdiv_values(ne11);
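+    // init_fastdiv_values precomputes multiply+shift constants so the per-row
+    // fastmodulo calls below avoid hardware integer division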
const uint32_t n_jobs = MIN(nr, octx->n_threads);
- octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+ srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
switch(octx->dst.type) {
case HTP_TYPE_F32:
- worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
break;
case HTP_TYPE_F16:
- worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
break;
default:
return HTP_STATUS_NO_SUPPORT;
#include "hex-dma.h"
#include "hvx-utils.h"
+#include "hex-fastdiv.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
-struct softmax_th_ctx {
+struct htp_softmax_context {
bool use_f16;
bool use_src1;
uint32_t n_head;
float m0;
float m1;
+ uint32_t src0_nrows_per_thread;
+ struct fastdiv_values fastdiv_ne01;
+ struct fastdiv_values fastdiv_ne02;
+ struct fastdiv_values fastdiv_ne12; // For mask broadcasting
+ struct fastdiv_values fastdiv_ne13; // For mask broadcasting
+ size_t spad_stride;
+
struct htp_ops_context * octx;
};
-static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
+static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
- memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
+ memset(smctx, 0, sizeof(struct htp_softmax_context));
+
+ memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
+ memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
+
+ smctx->n_head = src0->ne[2];
+ smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
+
+ smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
+ smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
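+    // ALiBi slope bases: heads below n_head_log2 take successive powers of m0,
+    // the remainder take odd powers of m1 (see the per-row slope computation)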
- memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
- memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
+ smctx->use_src1 = (src1->ne[0] != 0);
+ smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
- softmax_ctx->n_head = src0->ne[2];
- softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
+ smctx->octx = octx;
- softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
- softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
+ // Initialize fastdiv values
+ const uint32_t ne01 = src0->ne[1];
+ const uint32_t ne02 = src0->ne[2];
- softmax_ctx->use_src1 = (src1->ne[0] != 0);
- softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
+ if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
+ if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
- softmax_ctx->octx = octx;
+ const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
+ const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
+
+ if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
+ if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
}
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
}
- HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
- max_vec = hvx_vec_repl4(v);
+ max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
#pragma unroll(4)
for (int i = 0; i < step_of_1; i++) {
v_pad[i] = v3;
}
- v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
- sum_vec = hvx_vec_repl4(v);
+ sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
return sum;
}
-static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
- struct htp_ops_context * octx = softmax_ctx->octx;
-
- const struct htp_tensor * src0 = &octx->src0;
- const struct htp_tensor * src1 = &octx->src1;
- const struct htp_tensor * dst = &octx->dst;
-
- htp_softmax_preamble3;
-
- uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
- uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
- uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
-
- float * wp0 = (float *) src0_spad_data;
- float * wp1 = (float *) src1_spad_data;
- float * wp2 = (float *) dst_spad_data;
-
- for (uint32_t i03 = 0; i03 < ne03; i03++) {
- for (uint32_t i02 = 0; i02 < ne02; i02++) {
- for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
- const uint32_t i11 = i01;
- const uint32_t i12 = i02 % ne12;
- const uint32_t i13 = i03 % ne13;
-
- // ALiBi
- const uint32_t h = i02; // head
-
- const float slope = (softmax_ctx->max_bias > 0.0f) ?
- h < softmax_ctx->n_head_log2 ?
- powf(softmax_ctx->m0, h + 1) :
- powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
- 1.0f;
-
- float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
- float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
-
- // broadcast the mask across rows
- __fp16 * mp_f16 = (softmax_ctx->use_src1) ?
- (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
- NULL;
- float * mp_f32 = (softmax_ctx->use_src1) ?
- (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
- NULL;
-
- if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
- hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
- (const uint8_t *) mp_f32, slope);
- } else {
- hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
- if (mp_f32) {
- if (softmax_ctx->use_f16) {
- for (int i = 0; i < ne00; ++i) {
- wp0[i] += slope * (float) mp_f16[i];
- }
- } else {
- for (int i = 0; i < ne00; ++i) {
- wp0[i] += slope * mp_f32[i];
- }
- }
- }
- }
-
- if (1 == opt_path) {
- hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
- } else {
- float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
- float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
- sum = sum > 0.0 ? (1.0 / sum) : 1;
- hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
- }
- }
- }
- }
-}
-
-static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
- struct htp_ops_context * octx = softmax_ctx->octx;
+static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
+ struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
+ struct htp_ops_context * octx = smctx->octx;
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
htp_softmax_preamble3;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
- const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+ const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
opt_path = 1;
}
- softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
+ uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
+ uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
+ uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
+
+ float * wp0 = (float *) src0_spad_data;
+ float * wp1 = (float *) src1_spad_data;
+ float * wp2 = (float *) dst_spad_data;
+
+    uint32_t prev_i2 = (uint32_t) -1;
+ float slope = 1.0f;
+
+ for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
+ uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
+ uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
+ uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
+ uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
+
+        // The flattened row r maps back to the original indices as i01 = i1, i02 = i2, i03 = i3
+
+ const uint32_t i11 = i1;
+ // const uint32_t i12 = i2 % ne12;
+ // const uint32_t i13 = i3 % ne13;
+
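+        // Mask broadcast fast path: when src1 spans the full dim, use the index
+        // directly and skip the fastmodulo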
+ uint32_t i12, i13;
+ if (ne12 == ne02) {
+ i12 = i2;
+ } else {
+ i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
+ }
+
+ if (ne13 == ne03) {
+ i13 = i3;
+ } else {
+ i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
+ }
+
+ // ALiBi
+ if (i2 != prev_i2) {
+ const uint32_t h = i2; // head
+
+ slope = (smctx->max_bias > 0.0f) ?
+ h < smctx->n_head_log2 ?
+ powf(smctx->m0, h + 1) :
+ powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
+ 1.0f;
+ prev_i2 = i2;
+ }
+
+ float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
+ float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
+
+ // broadcast the mask across rows
+ __fp16 * mp_f16 = (smctx->use_src1) ?
+ (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+ NULL;
+ float * mp_f32 = (smctx->use_src1) ?
+ (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
+ NULL;
+
+ if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
+ hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
+ (const uint8_t *) mp_f32, slope);
+ } else {
+ hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
+ if (mp_f32) {
+ if (smctx->use_f16) {
+ for (int i = 0; i < ne00; ++i) {
+ wp0[i] += slope * (float) mp_f16[i];
+ }
+ } else {
+ for (int i = 0; i < ne00; ++i) {
+ wp0[i] += slope * mp_f32[i];
+ }
+ }
+ }
+ }
+
+ if (1 == opt_path) {
+ hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
+ } else {
+ float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
+ float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
+            sum = sum > 0.0f ? (1.0f / sum) : 1.0f;
+ hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
+ }
+ }
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
- softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
+ smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
- struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
- softmax_job_f32_per_thread(p_softmax_ctx, n, i);
-}
-
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
const struct htp_tensor * src1 = &octx->src1;
struct htp_tensor * dst = &octx->dst;
- worker_callback_t op_func;
- const char * op_type = NULL;
-
- struct softmax_th_ctx softmax_ctx;
+ struct htp_softmax_context smctx;
+ const char * op_type = "softmax-f32";
switch (octx->op) {
case HTP_OP_SOFTMAX:
- op_func = softmax_job_dispatcher_f32;
- op_type = "softmax-f32";
-
- init_softmax_ctx(&softmax_ctx, octx);
+ init_softmax_ctx(&smctx, octx);
break;
default:
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
+    // Per-thread stride into the scratchpads: one 128-byte-padded src0 row per thread
+    smctx.spad_stride = hex_round_up(src0_row_size, 128);
+
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
if (src1->ne[0]) {
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
+ smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
}
return err;
#include "htp-msg.h"
#include "htp-ops.h"
-
#define sum_rows_preamble \
struct htp_tensor *src0 = &octx->src0;\
struct htp_tensor *dst = &octx->dst; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; \
-static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
- sum_rows_preamble;
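+// Precomputed, read-only job description for sum-rows: strides and row counts are
+// resolved once on the main thread, then shared by all workers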
+struct sum_rows_context {
+ const uint8_t * src_data;
+ uint8_t * dst_data;
+ uint32_t ne00;
+ size_t src_stride;
+ size_t dst_stride;
+ uint32_t rows_per_thread;
+ uint32_t total_rows;
+ bool opt_path;
+};
- const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
- const size_t src0_row_size = nb01;
- const size_t dst_row_size = nb1;
+static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void * data) {
+    const struct sum_rows_context * sctx = (const struct sum_rows_context *) data;
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+    const uint32_t rows_per_thread = sctx->rows_per_thread;
+    const uint32_t total_rows = sctx->total_rows;
- const uint32_t src0_start_row = src0_nrows_per_thread * ith;
- const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+ const uint32_t start_row = rows_per_thread * ith;
+ const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
- // no work for this thread
- if (src0_start_row >= src0_end_row) {
- return HTP_STATUS_OK;
+ if (start_row >= end_row) {
+ return;
}
- int opt_path = 0;
- if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
- opt_path = 1;
- }
+    const size_t src_stride = sctx->src_stride;
+    const size_t dst_stride = sctx->dst_stride;
+    const uint32_t ne00 = sctx->ne00;
+    const bool opt_path = sctx->opt_path;
- const uint8_t * restrict data_src = (const uint8_t *) src0->data;
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
+    const float * restrict src_th = (const float *) (sctx->src_data + (start_row * src_stride));
+    float * restrict dst_th = (float *) (sctx->dst_data + (start_row * dst_stride));
+    const size_t src_row_elems = src_stride / sizeof(float);
- const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
- float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
+ // Calculate actual number of rows for this thread
+ const uint32_t n_rows = end_row - start_row;
- for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
- const float * restrict src_local = src_th + (ir * ne00);
+    for (uint32_t ir = 0; ir < n_rows; ir++) {
+        const float * restrict src_local = src_th + (ir * src_row_elems);
- if (ir + 1 < src0_nrows_per_thread) {
- hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
+        if (ir + 1 < n_rows) {
+            hex_l2fetch(src_local + src_row_elems, src_stride, src_stride, 1);
}
- if (1 == opt_path) {
+ if (opt_path) {
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
} else {
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
}
}
-
- return HTP_STATUS_OK;
-}
-
-static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
- sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
}
int op_sum_rows(struct htp_ops_context * octx) {
const uint32_t src0_nrows = ne01 * ne02 * ne03;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
- worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
+ bool opt_path = false;
+    if (hex_is_aligned((void *) src0->data, VLEN) && !(nb01 & (VLEN - 1))) {
+ opt_path = true;
+ }
+
+    struct sum_rows_context sctx = {
+ .src_data = (const uint8_t *) src0->data,
+ .dst_data = (uint8_t *) dst->data,
+ .ne00 = ne00,
+ .src_stride = nb01,
+ .dst_stride = nb1,
+ .rows_per_thread = rows_per_thread,
+ .total_rows = src0_nrows,
+ .opt_path = opt_path,
+ };
+
+    worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &sctx, n_jobs);
return HTP_STATUS_OK;
}
-
#include "htp-msg.h"
#include "htp-ops.h"
+struct htp_unary_context {
+ struct htp_ops_context * octx;
+
+ // Precomputed values
+ const uint8_t * data_src0;
+ uint8_t * data_dst;
+
+ size_t src0_row_size;
+ size_t dst_row_size;
+
+ size_t src0_row_size_aligned;
+ size_t dst_row_size_aligned;
+
+ size_t src0_spad_half_size;
+ size_t dst_spad_half_size;
+
+ uint32_t block;
+ uint32_t src0_nrows;
+ uint32_t src0_nrows_per_thread;
+ uint32_t nc;
+};
+
#define htp_unary_preamble \
const uint32_t ne00 = src->ne[0]; \
const uint32_t ne01 = src->ne[1]; \
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
}
- HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
- sum_v = hvx_vec_repl4(reduced_sum);
+ sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
}
}
-static void scale_htp_f32(const float * restrict src,
- float * restrict dst,
- uint8_t * restrict spad,
- const uint32_t num_rows,
- const uint32_t row_elems,
- const size_t row_size,
- int32_t * op_params,
- int opt_path) {
+static void scale_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params) {
float scale = 0.f;
float bias = 0.f;
memcpy(&scale, &op_params[0], sizeof(float));
memcpy(&bias, &op_params[1], sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
- const float * restrict src_local = src + (ir * row_elems);
- float * restrict dst_local = dst + (ir * row_elems);
-
- if (ir + 1 < num_rows) {
- hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
- }
+        const uint8_t * restrict src_local = (const uint8_t *) src + (ir * row_size);
+        uint8_t * restrict dst_local = (uint8_t *) dst + (ir * row_size);
-        hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
+        hvx_scale_offset_f32_aa(dst_local, src_local, row_elems, scale, bias);
}
}
-static void rms_norm_htp_f32(const float * restrict src,
- float * restrict dst,
- uint8_t * restrict spad,
- const uint32_t num_rows,
- const uint32_t row_elems,
- const size_t row_size,
- int32_t * op_params,
- int opt_path) {
+static void rms_norm_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params) {
float epsilon = 0.f;
memcpy(&epsilon, op_params, sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
- const float * restrict src_local = src + (ir * row_elems);
- float * restrict dst_local = dst + (ir * row_elems);
+        const uint8_t * restrict src_local = (const uint8_t *) src + (ir * row_size);
+        uint8_t * restrict dst_local = (uint8_t *) dst + (ir * row_size);
- if (ir + 1 < num_rows) {
- hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
- }
-
- if (1 == opt_path) {
- hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
- } else {
- float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
-
- const float mean = sum / row_elems;
- const float scale = 1.0f / sqrtf(mean + epsilon);
-
- hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
- }
+        hvx_fast_rms_norm_f32(src_local, dst_local, spad, row_elems, epsilon);
}
}
-static void sqr_htp_f32(const float * restrict src,
- float * restrict dst,
- uint8_t * restrict spad,
- const uint32_t num_rows,
- const uint32_t row_elems,
- const size_t row_size,
- int32_t * op_params,
- int opt_path) {
+static void sqr_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
- const float * restrict src_local = src + (ir * row_elems);
- float * restrict dst_local = dst + (ir * row_elems);
-
- if (ir + 1 < num_rows) {
- hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
- }
+        const uint8_t * restrict src_local = (const uint8_t *) src + (ir * row_size);
+        uint8_t * restrict dst_local = (uint8_t *) dst + (ir * row_size);
- if (1 == opt_path) {
- hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
- } else {
- hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
- }
+        hvx_sqr_f32_aa(dst_local, src_local, row_elems);
}
}
-static void sqrt_htp_f32(const float * restrict src,
- float * restrict dst,
- uint8_t * restrict spad,
- const uint32_t num_rows,
- const uint32_t row_elems,
- const size_t row_size,
- int32_t * op_params,
- int opt_path) {
+static void sqrt_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
- const float * restrict src_local = src + (ir * row_elems);
- float * restrict dst_local = dst + (ir * row_elems);
+        const uint8_t * restrict src_local = (const uint8_t *) src + (ir * row_size);
+        uint8_t * restrict dst_local = (uint8_t *) dst + (ir * row_size);
- if (ir + 1 < num_rows) {
- hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
- }
-
- if (1 == opt_path) {
- hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
- } else {
- hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
- }
+        hvx_sqrt_f32_aa(dst_local, src_local, row_elems);
}
}
-static void unary_job_f32_per_thread(const struct htp_tensor * src,
- struct htp_tensor * dst,
- uint8_t * spad,
- int htp_op,
- int32_t * op_params,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread) {
+static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
+ const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
+ struct htp_ops_context * octx = uctx->octx;
+ const struct htp_tensor * src = &octx->src0;
+ const struct htp_tensor * dst = &octx->dst;
+
htp_unary_preamble;
- const size_t src0_row_size = nb01;
- const size_t dst_row_size = nb1;
+ int htp_op = octx->op;
+ int32_t * op_params = octx->op_params;
+ uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
- const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
+ const size_t src0_row_size = uctx->src0_row_size;
+ const size_t dst_row_size = uctx->dst_row_size;
+ const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
+ const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
+
+ const uint32_t src0_nrows = uctx->src0_nrows;
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
- int is_aligned = 1;
- int opt_path = 0;
- if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
- is_aligned = 0;
- }
- if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
- opt_path = 1;
+ const uint8_t * restrict data_src = uctx->data_src0;
+ uint8_t * restrict data_dst = uctx->data_dst;
+
+ uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
+ uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
+
+ size_t src0_spad_half_size = uctx->src0_spad_half_size;
+ size_t dst_spad_half_size = uctx->dst_spad_half_size;
+
+    const uint32_t BLOCK = uctx->block;
+ if (BLOCK == 0) {
+ FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
+ octx->src0_spad.size_per_thread, src0_row_size_aligned);
+ return;
}
- const uint8_t * restrict data_src = (const uint8_t *) src->data;
- uint8_t * restrict data_dst = (uint8_t *) dst->data;
+ dma_queue * dma_queue = octx->ctx->dma[ith];
- const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
- float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
- uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
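+    // Prime the pipeline: queue up to two blocks of DMA (one per spad half) before computing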
+ for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
- switch (htp_op) {
- case HTP_OP_RMS_NORM:
- rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
- break;
- case HTP_OP_SCALE:
- scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
- break;
- case HTP_OP_SQR:
- sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
- break;
- case HTP_OP_SQRT:
- sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
- break;
+        // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
+ dst_row_size, dst_row_size_aligned, 0);
- default:
- break;
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, block_size);
}
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+ const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
+
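+        // In-order pops: the first returns this block's dst half (dummy push),
+        // the second returns the src half whose DMA fill has completed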
+ float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
+ float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
+
+ // Process block in VTCM
+ switch (htp_op) {
+ case HTP_OP_RMS_NORM:
+ rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
+ case HTP_OP_SCALE:
+ scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
+ case HTP_OP_SQR:
+ sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
+ case HTP_OP_SQRT:
+ sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
+ break;
+ default:
+ break;
+ }
+
+ dma_queue_push_vtcm_to_ddr(dma_queue,
+ dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
+ dst_row_size, dst_row_size_aligned, block_size);
+
+ // prefetch N+2 loop iteration if any
+        // Prefetch the block for loop iteration N+2 (if any) into the src half we just consumed
+ if (pref_block < src0_end_row) {
+ const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
+ dma_queue_push_ddr_to_vtcm(dma_queue,
+ dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
+ src0_row_size_aligned, src0_row_size, pref_block_size);
+ }
+ }
+
+ dma_queue_flush(dma_queue);
+
t2 = HAP_perf_get_qtimer_count();
- FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
+ FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
- struct htp_ops_context * octx = (struct htp_ops_context *) data;
-
- unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
- octx->src0_nrows_per_thread);
-}
-
static int execute_op_unary_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
const struct htp_tensor * src0 = &octx->src0;
struct htp_tensor * dst = &octx->dst;
- worker_callback_t unary_op_func;
- const char * op_type = NULL;
+ const char * op_type = NULL;
switch (octx->op) {
case HTP_OP_RMS_NORM:
- unary_op_func = unary_job_dispatcher_f32;
- op_type = "rmsnorm-f32";
+ op_type = "rmsnorm-f32";
break;
case HTP_OP_SCALE:
- unary_op_func = unary_job_dispatcher_f32;
- op_type = "scale-f32";
+ op_type = "scale-f32";
break;
case HTP_OP_SQR:
- unary_op_func = unary_job_dispatcher_f32;
- op_type = "sqr-f32";
+ op_type = "sqr-f32";
break;
case HTP_OP_SQRT:
- unary_op_func = unary_job_dispatcher_f32;
- op_type = "sqrt-f32";
+ op_type = "sqrt-f32";
break;
default:
const size_t src0_row_size = src0->nb[1];
const size_t dst_row_size = dst->nb[1];
- // VTCM scratchpads for all tensors
- octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
- octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
+ const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
+ const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
- size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
+ // VTCM scratchpads for all tensors
+ // N rows per thread, padded to HVX vector size
+ // Double buffering requires 2x size per buffer
- FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
- octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+    size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
+    size_t vtcm_rows_per_thread = octx->ctx->vtcm_size / (n_threads * spad_size_per_row);
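+    // spad_size_per_row counts both halves of the double buffer, so vtcm_rows_per_thread
+    // is the number of rows that fit in one half (i.e. the DMA block size)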
// Make sure the reserved vtcm size is sufficient
- if (octx->ctx->vtcm_size < spad_size) {
+    if (vtcm_rows_per_thread == 0) {
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
- spad_size);
+ spad_size_per_row * n_threads);
return HTP_STATUS_VTCM_TOO_SMALL;
}
+    octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_rows_per_thread * 2;
+    octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_rows_per_thread * 2;
+
+ octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
+ octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
+
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
+
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
uint32_t n_jobs = MIN(n_threads, src0_nrows);
- octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
+ struct htp_unary_context uctx = {
+ .octx = octx,
+ .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
+ .src0_nrows = src0_nrows,
+
+ .data_src0 = (const uint8_t *)src0->data,
+ .data_dst = (uint8_t *)dst->data,
+
+ .src0_row_size = src0_row_size,
+ .dst_row_size = dst_row_size,
+
+ .src0_row_size_aligned = src0_row_size_aligned,
+ .dst_row_size_aligned = dst_row_size_aligned,
+
+ .src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
+ .dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
+
+ .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
+ .nc = src0->ne[0],
+ };
- worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
+ worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
}
return err;