return true;
}
+static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0];
+ const struct ggml_tensor * src1 = op->src[1];
+ const struct ggml_tensor * src2 = op->src[2];
+ const struct ggml_tensor * src3 = op->src[3];
+ const struct ggml_tensor * src4 = op->src[4];
+ const struct ggml_tensor * dst = op;
+
+    // Q may be F32 or F16; K and V must be F16
+ if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ if (src3 && src3->type != GGML_TYPE_F16) { // mask
+ return false;
+ }
+
+ if (src4 && src4->type != GGML_TYPE_F32) { // sinks
+ return false;
+ }
+
+    // The op implementation can write either F32 or F16 output directly, so accept both.
+ if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ return opt_experimental;
+}
+
static bool hex_supported_src0_type(ggml_type t) {
return t == GGML_TYPE_F32;
}
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
- if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+ if (dst->type != GGML_TYPE_F32) {
return false;
}
- // TODO: add support for non-cont tensors
- if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
+ if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
return false;
}
return false; // typically the lm-head which would be too large for VTCM
}
- // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
return false;
}
}
break;
- case GGML_TYPE_F16:
- if (!opt_experimental) {
- return false;
- }
- break;
-
default:
return false;
}
- // TODO: add support for non-cont tensors
- if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
- return false;
- }
-
return true;
}
return true;
}
+static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0]; // values
+ const struct ggml_tensor * src1 = op->src[1]; // indices
+ const struct ggml_tensor * dst = op;
+
+ if (src0->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
+ return false;
+ }
+
+ if (dst->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
+ const struct ggml_tensor * src0 = op->src[0]; // values
+ const struct ggml_tensor * src1 = op->src[1]; // indices
+ const struct ggml_tensor * dst = op;
+
+ if (src0->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
+ return false;
+ }
+
+ if (dst->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ return true;
+}
+
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0];
d->offset = (uint8_t *) t->data - buf->base;
d->size = ggml_nbytes(t);
+ if (!d->size) {
+        // Some requests contain srcs for which ggml_nbytes() returns 0 even though the op itself
+        // is non-empty; use a small non-zero size for such buffers.
+ d->size = 64;
+ }
+
switch (type) {
case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
// Flush CPU
return n_bufs;
}
+static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ req->op = HTP_OP_GET_ROWS;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) {
return n_bufs;
}
+static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ req->op = HTP_OP_SET_ROWS;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
supported = true;
break;
+ case GGML_OP_SCALE:
+ req->op = HTP_OP_SCALE;
+ supported = true;
+ break;
+
case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
req->op = HTP_OP_UNARY_SILU;
return n_bufs;
}
+static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
+ memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
+ req->op = HTP_OP_FLASH_ATTN_EXT;
+
+ size_t n_bufs = 0;
+ n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
+ n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
+
+ return n_bufs;
+}
+
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
auto sess = static_cast<ggml_hexagon_session *>(backend->context);
return sess->name.c_str();
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
break;
case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_UNARY:
ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_SET_ROWS:
+ ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
+ break;
+
+ case GGML_OP_GET_ROWS:
+ ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
+ break;
+
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
break;
case GGML_OP_RMS_NORM:
+ case GGML_OP_SCALE:
supp = ggml_hexagon_supported_unary(sess, op);
break;
supp = ggml_hexagon_supported_rope(sess, op);
break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
+ break;
+
+ case GGML_OP_SET_ROWS:
+ supp = ggml_hexagon_supported_set_rows(sess, op);
+ break;
+
+ case GGML_OP_GET_ROWS:
+ supp = ggml_hexagon_supported_get_rows(sess, op);
+ break;
+
default:
break;
}
softmax-ops.c
act-ops.c
rope-ops.c
+ flash-attn-ops.c
+ set-rows-ops.c
+ get-rows-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
--- /dev/null
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+# define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+// Dot product of FP32 and FP16 vectors, accumulating to float
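+// r: output scalar; y: fp32 input; x: fp16 input; n: number of elements; s: scale applied to the reduced sum.
+// The _aa suffix means both input pointers are expected to be VLEN (128-byte) aligned.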
+static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
+ const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ // Zero-out unused elements
+ // Note that we need to clear both x and y because they may contain NANs
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x_hf = Q6_V_vand_QV(bmask, x_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+
+ hvx_vec_store_u(r, 4, rsum);
+}
+
+// Dot product of two F16 vectors, accumulating to float
+static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
+ const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = vy[i];
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ HVX_Vector y_hf = vy[i];
+
+ // Load x (fp16) and zero-out unused elements
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(r, 4, rsum);
+}
+
+// MAD: y (F32) += x (F16) * s (float)
+static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
+ const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
+ HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector S = hvx_vec_splat_fp16(s);
+
+ uint32_t i = 0;
+ #pragma unroll(4)
+ for (i = 0; i < nvec; ++i) {
+ // Multiply x * s -> pair of F32 vectors
+ HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+ ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
+ ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
+ }
+
+ if (nloe) {
+ HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+
+ HVX_Vector xs = Q6_V_lo_W(xs_p);
+ i = 2 * i; // index for ptr_y
+
+ if (nloe >= 32) {
+ ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+ nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
+ }
+
+ if (nloe) {
+ HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+ hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
+ }
+ }
+}
+
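+// Number of K/V rows staged into VTCM per DMA block; two blocks are kept in flight per thread (see prefetch below).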
+#define FLASH_ATTN_BLOCK_SIZE 128
+
+static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
+ const struct htp_tensor * q = &octx->src0;
+ const struct htp_tensor * k = &octx->src1;
+ const struct htp_tensor * v = &octx->src2;
+ const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
+ const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
+ struct htp_tensor * dst = &octx->dst;
+
+ const uint32_t neq0 = q->ne[0];
+ const uint32_t neq1 = q->ne[1];
+ const uint32_t neq2 = q->ne[2];
+ const uint32_t neq3 = q->ne[3];
+
+ const uint32_t nek0 = k->ne[0];
+ const uint32_t nek1 = k->ne[1];
+ const uint32_t nek2 = k->ne[2];
+ const uint32_t nek3 = k->ne[3];
+
+ const uint32_t nev0 = v->ne[0];
+ const uint32_t nev1 = v->ne[1];
+ const uint32_t nev2 = v->ne[2];
+ const uint32_t nev3 = v->ne[3];
+
+ const uint32_t nbq1 = q->nb[1];
+ const uint32_t nbq2 = q->nb[2];
+ const uint32_t nbq3 = q->nb[3];
+
+ const uint32_t nbk1 = k->nb[1];
+ const uint32_t nbk2 = k->nb[2];
+ const uint32_t nbk3 = k->nb[3];
+
+ const uint32_t nbv1 = v->nb[1];
+ const uint32_t nbv2 = v->nb[2];
+ const uint32_t nbv3 = v->nb[3];
+
+ const uint32_t ne1 = dst->ne[1];
+ const uint32_t ne2 = dst->ne[2];
+ const uint32_t ne3 = dst->ne[3];
+
+ const uint32_t nb1 = dst->nb[1];
+ const uint32_t nb2 = dst->nb[2];
+ const uint32_t nb3 = dst->nb[3];
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
+
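+    // With softcapping, scores are computed as logit_softcap * tanh(s * scale / logit_softcap):
+    // fold the 1/logit_softcap factor into scale here and multiply by logit_softcap after tanh below.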
+ if (logit_softcap != 0) {
+ scale /= logit_softcap;
+ }
+
+ // total rows in q
+ const uint32_t nr = neq1*neq2*neq3;
+
+ const uint32_t dr = (nr + nth - 1) / nth;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = MIN(ir0 + dr, nr);
+
+ if (ir0 >= ir1) return;
+
+ dma_queue * dma = octx->ctx->dma[ith];
+
+ const uint32_t DK = nek0;
+ const uint32_t DV = nev0;
+
+ const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
+ const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
+
+ const size_t size_k_row = DK * sizeof(__fp16);
+ const size_t size_v_row = DV * sizeof(__fp16);
+ const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
+
+ const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
+ const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
+
+ const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+ // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
+ uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
+ uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
+ uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
+ uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
+ uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
+
+ const uint32_t n_head = neq2;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ for (uint32_t ir = ir0; ir < ir1; ++ir) {
+ const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
+ const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
+ const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
+
+ const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
+ const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
+
+ const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
+ const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
+
+ // Fetch Q row
+ const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
+ dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
+
+ const uint32_t h = iq2; // head index
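+        // ALiBi slope for this head: 1.0 when max_bias == 0, otherwise m0^(h+1) for the first
+        // n_head_log2 heads and m1^(2*(h - n_head_log2) + 1) for the remaining heads.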
+ const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
+
+ float S = 0.0f; // sum
+ float M = -INFINITY; // maximum KQ value
+
+ // Clear accumulator
+ float * VKQ32 = (float *) spad_a;
+ memset(VKQ32, 0, DV * sizeof(float));
+
+ const __fp16 * mp_base = NULL;
+ if (mask) {
+ const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
+ const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
+ mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
+ }
+
+ const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
+
+ // Prefetch first two blocks
+ for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
+ const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+ // K
+ const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+ uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
+ dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
+
+ // V
+ const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+ uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
+ dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
+
+ // Mask
+ if (mask) {
+ const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
+ uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
+ // Mask is 1D contiguous for this row
+ dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
+ }
+ }
+
+ const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+
+ for (uint32_t ib = 0; ib < n_blocks; ++ib) {
+ const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+ // Wait for DMA
+ uint8_t * k_base = dma_queue_pop(dma).dst; // K
+ uint8_t * v_base = dma_queue_pop(dma).dst; // V
+ __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
+
+ // Inner loop processing the block from VTCM
+ uint32_t ic = 0;
+
+ // Process in blocks of 32 (VLEN_FP32)
+ for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
+ // 1. Compute scores
+ float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+ for (int j = 0; j < VLEN_FP32; ++j) {
+ const uint32_t cur_ic = ic + j;
+ const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
+ if (q->type == HTP_TYPE_F32) {
+ hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+ } else {
+ hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+ }
+ }
+
+ HVX_Vector scores = *(HVX_Vector *) scores_arr;
+
+ // 2. Softcap
+ if (logit_softcap != 0.0f) {
+ scores = hvx_vec_tanh_fp32(scores);
+ scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
+ scores = Q6_Vsf_equals_Vqf32(scores);
+ }
+
+ // 3. Mask
+ if (mask) {
+ const __fp16 * mp = m_base + ic;
+ HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
+
+ HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
+ HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
+
+ HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
+
+ HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
+ HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
+ scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
+ scores = Q6_Vsf_equals_Vqf32(scores);
+ }
+
+ // 4. Online Softmax Update
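+                // Streaming softmax: M_new = max(M_old, max(scores)); rescale S and the V accumulator
+                // by exp(M_old - M_new), then add sum(exp(scores - M_new)) to S and use the
+                // exponentiated scores P as per-column weights when accumulating V below.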
+ HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
+ float m_block = hvx_vec_get_fp32(v_max);
+
+ float M_old = M;
+ float M_new = (m_block > M) ? m_block : M;
+ M = M_new;
+
+ float ms = expf(M_old - M_new);
+
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ S = S * ms;
+
+ HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
+ HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
+ HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+ HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
+ float p_sum = hvx_vec_get_fp32(p_sum_vec);
+ S += p_sum;
+
+ // 5. Accumulate V
+ float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
+ *(HVX_Vector*)p_arr = P;
+
+ for (int j = 0; j < VLEN_FP32; ++j) {
+ const uint32_t cur_ic = ic + j;
+ const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+ }
+ }
+
+ // Leftover
+ for (; ic < current_block_size; ++ic) {
+ float s_val;
+ const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
+
+ if (q->type == HTP_TYPE_F32) {
+ hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+ } else {
+ hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+ }
+
+ if (logit_softcap != 0.0f) {
+ s_val = logit_softcap * tanhf(s_val);
+ }
+
+ if (mask) {
+ const float m_val = m_base[ic];
+ s_val += slope * m_val;
+ }
+
+ const float Mold = M;
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s_val > M) {
+ M = s_val;
+ ms = expf(Mold - M);
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ } else {
+ vs = expf(s_val - M);
+ }
+
+ const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
+
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
+
+ S = S * ms + vs;
+ }
+
+ // Issue DMA for next+1 block (if exists)
+ if (ib + 2 < n_blocks) {
+ const uint32_t next_ib = ib + 2;
+ const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
+
+ // K
+ const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+ dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
+
+ // V
+ const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+ dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
+
+ // Mask
+ if (mask) {
+ const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
+ dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
+ }
+ }
+ }
+
+ // sinks
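+        // Attention sinks add one extra per-head logit to the softmax normalization:
+        // it updates M and S like a regular score but contributes nothing to the V accumulator.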
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ ms = expf(M - s);
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ S = S * ms + vs;
+ }
+
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
+
+ // Store result
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ // dst is permuted
+ uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
+
+ if (dst->type == HTP_TYPE_F32) {
+ hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+ } else if (dst->type == HTP_TYPE_F16) {
+ hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+ }
+ }
+}
+
+static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+ flash_attn_ext_f16_thread(octx, i, n);
+}
+
+int op_flash_attn_ext(struct htp_ops_context * octx) {
+ const struct htp_tensor * q = &octx->src0;
+ const struct htp_tensor * k = &octx->src1;
+ const struct htp_tensor * v = &octx->src2;
+ const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
+ struct htp_tensor * dst = &octx->dst;
+
+ // Check support
+ if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
+ k->type != HTP_TYPE_F16 ||
+ v->type != HTP_TYPE_F16) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
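+    // Precompute fastdiv values for decomposing the flat Q row index and for the Q->K/V
+    // head broadcast ratios (GQA); these are used per row in the worker threads.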
+ octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
+ octx->src0_div1 = init_fastdiv_values(q->ne[1]);
+
+ octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
+ octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
+ octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
+ octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
+
+ if (mask) {
+ octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
+ octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
+ }
+
+ size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
+ size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
+ size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
+
+ size_t size_q_block = size_q_row_padded * 1; // single row for now
+ size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+ size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
+
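+    // K/V/mask blocks are double-buffered per thread so the DMA fetch of block ib+1 can overlap
+    // compute on block ib; Q is a single row and the accumulator holds one DV-wide fp32 row.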
+ octx->src0_spad.size_per_thread = size_q_block * 1;
+ octx->src1_spad.size_per_thread = size_k_block * 2;
+ octx->src2_spad.size_per_thread = size_v_block * 2;
+ octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
+ octx->dst_spad.size_per_thread = size_vkq_acc;
+
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+ octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
+ octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+
+ size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
+
+ if (octx->ctx->vtcm_size < total_spad) {
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+ octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
+ octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
+ }
+
+ return HTP_STATUS_OK;
+}
--- /dev/null
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+# define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define get_rows_preamble \
+ const uint32_t ne00 = octx->src0.ne[0]; \
+ const uint32_t ne01 = octx->src0.ne[1]; \
+ const uint32_t ne02 = octx->src0.ne[2]; \
+ const uint32_t ne03 = octx->src0.ne[3]; \
+ \
+ const uint32_t ne10 = octx->src1.ne[0]; \
+ const uint32_t ne11 = octx->src1.ne[1]; \
+ const uint32_t ne12 = octx->src1.ne[2]; \
+ \
+ const uint32_t nb01 = octx->src0.nb[1]; \
+ const uint32_t nb02 = octx->src0.nb[2]; \
+ const uint32_t nb03 = octx->src0.nb[3]; \
+ \
+ const uint32_t nb10 = octx->src1.nb[0]; \
+ const uint32_t nb11 = octx->src1.nb[1]; \
+ const uint32_t nb12 = octx->src1.nb[2]; \
+ \
+ const uint32_t nb1 = octx->dst.nb[1]; \
+ const uint32_t nb2 = octx->dst.nb[2]; \
+ const uint32_t nb3 = octx->dst.nb[3]; \
+ \
+ const uint32_t nr = ne10 * ne11 * ne12;
+
+static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+ get_rows_preamble;
+
+ // parallelize by src1 elements (which correspond to dst rows)
+ const uint32_t dr = octx->src1_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+ const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+ for (uint32_t i = ir0; i < ir1; ++i) {
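+        // Decompose the flat index i into (i10, i11, i12) using the precomputed fastdiv values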
+ const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
+ const uint32_t rem = i - i12 * ne11 * ne10;
+ const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
+ const uint32_t i10 = rem - i11 * ne10;
+
+ const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+ uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+
+ if (i01 >= ne01) {
+ // invalid index, skip for now to avoid crash
+ continue;
+ }
+
+ const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
+ const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
+ hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+ }
+
+ return HTP_STATUS_OK;
+}
+
+static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
+ get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_get_rows(struct htp_ops_context * octx) {
+ get_rows_preamble;
+
+ if (octx->src0.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->dst.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
+ }
+
+ octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
+ octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
+
+ const uint32_t n_jobs = MIN(nr, octx->n_threads);
+ octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
+
+ worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
+ return HTP_STATUS_OK;
+}
#define HTP_MAX_NTHREADS 10
-// FIXME: move these into matmul-ops
-#define HTP_SPAD_SRC0_NROWS 16
-#define HTP_SPAD_SRC1_NROWS 16
-#define HTP_SPAD_DST_NROWS 2
-
// Main context for htp DSP backend
struct htp_context {
dspqueue_t queue;
HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8,
+ HTP_TYPE_I32 = 26,
+ HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT
};
HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13,
+ HTP_OP_FLASH_ATTN_EXT = 14,
+ HTP_OP_SET_ROWS = 15,
+ HTP_OP_SCALE = 16,
+ HTP_OP_GET_ROWS = 17,
INVALID
};
struct htp_tensor src0; // Input0 tensor
struct htp_tensor src1; // Input1 tensor
struct htp_tensor src2; // Input2 tensor
+ struct htp_tensor src3; // Input3 tensor
+ struct htp_tensor src4; // Input4 tensor
struct htp_tensor dst; // Output tensor
// should be multiple of 64 bytes (cacheline)
};
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
-#define HTP_MAX_PACKET_BUFFERS 4
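+// flash-attn-ext requests carry up to 6 buffers (Q, K, V, optional mask, optional sinks, dst)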
+#define HTP_MAX_PACKET_BUFFERS 8
#endif /* HTP_MSG_H */
struct htp_spad {
uint8_t * data;
+ size_t stride;
size_t size;
size_t size_per_thread;
};
struct htp_tensor src0;
struct htp_tensor src1;
struct htp_tensor src2;
+ struct htp_tensor src3;
+ struct htp_tensor src4;
struct htp_tensor dst;
struct htp_spad src0_spad;
struct htp_spad src1_spad;
struct htp_spad src2_spad;
+ struct htp_spad src3_spad;
struct htp_spad dst_spad;
worker_pool_context_t * wpool; // worker pool
struct fastdiv_values src1_div3; // fastdiv values for ne3
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
+ struct fastdiv_values src3_div1; // fastdiv values for ne1
+ struct fastdiv_values src3_div2; // fastdiv values for ne2
+ struct fastdiv_values src3_div3; // fastdiv values for ne3
+ struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
+
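+    // fastdiv values for the Q/K and Q/V head broadcast ratios (GQA) used by flash-attn-ext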
+ struct fastdiv_values broadcast_rk2;
+ struct fastdiv_values broadcast_rk3;
+ struct fastdiv_values broadcast_rv2;
+ struct fastdiv_values broadcast_rv3;
+
+ struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
+ struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
+ struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
+ struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
+
+ struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
+ struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
+
+ struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
+ struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
+
uint32_t flags;
};
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx);
+int op_flash_attn_ext(struct htp_ops_context * octx);
+int op_set_rows(struct htp_ops_context * octx);
+int op_get_rows(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
}
-void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
- int left_over = num_elems & (VLEN_FP32 - 1);
- int num_elems_whole = num_elems - left_over;
-
- int unaligned_addr = 0;
- int unaligned_loop = 0;
- if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
- FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
- unaligned_addr = 1;
- }
-
- if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
- unaligned_loop = 1;
- FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
- }
-
- HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
-
- if (0 == unaligned_loop) {
- HVX_Vector * vec_in1 = (HVX_Vector *) src;
- HVX_Vector * vec_out = (HVX_Vector *) dst;
-
- #pragma unroll(4)
- for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
- HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
- *vec_out++ = Q6_Vsf_equals_Vqf32(v);
- }
- } else {
- #pragma unroll(4)
- for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
- HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-
- HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
-
- *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
- }
- }
-
- if (left_over > 0) {
- const float * srcf = (const float *) src + num_elems_whole;
- float * dstf = (float *) dst + num_elems_whole;
-
- HVX_Vector in = *(HVX_UVector *) srcf;
-
- HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
- hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
- }
-}
-
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
}
}
}
#endif
-static inline HVX_Vector hvx_vec_splat_fp32(float i) {
+static inline HVX_Vector hvx_vec_splat_fp32(float v) {
union {
- float f;
- int32_t i;
- } fp32 = { .f = i };
+ float f;
+ uint32_t i;
+ } fp32 = { .f = v };
return Q6_V_vsplat_R(fp32.i);
}
+static inline HVX_Vector hvx_vec_splat_fp16(float v) {
+ union {
+ __fp16 f;
+ uint16_t i;
+ } fp16 = { .f = v };
+
+ return Q6_Vh_vsplat_R(fp16.i);
+}
+
static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
// Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
}
}
+// copy n fp32 elements : source is unaligned, destination unaligned
+static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ HVX_UVector * restrict vdst = (HVX_UVector *) dst;
+ HVX_UVector * restrict vsrc = (HVX_UVector *) src;
+
+ uint32_t nvec = n / 32;
+ uint32_t nloe = n % 32;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (; i < nvec; i++) {
+ HVX_Vector v = vsrc[i];
+ vdst[i] = v;
+ }
+
+ if (nloe) {
+ HVX_Vector v = vsrc[i];
+ hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
+ }
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
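+// fp32 inputs are converted to qf32 (subtract of zero), narrowed to fp16 with Q6_Vhf_equals_Wqf32,
+// and re-ordered with vdeal to restore the original element order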
+static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
+ HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ uint32_t nvec = n / 64;
+ uint32_t nloe = n % 64;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+ HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+ HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+ vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+ HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+ HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+ hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
+ }
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
+static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
+ HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ uint32_t nvec = n / 64;
+ uint32_t nloe = n % 64;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+ HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+ HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+ vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+ HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+ HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+ hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
+ }
+}
+
+// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
+static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+ HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16
+ HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ uint32_t nvec = n / 64;
+ uint32_t nloe = n % 64;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+ HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+ HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+ vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
+ HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
+ HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
+ hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
+ }
+}
+
// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
return right_off <= chunk_size;
}
-
-
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
HVX_VectorAlias u = { .v = v };
}
static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
-#if __HTP_ARCH__ > 75
+#if __HVX_ARCH__ > 75
return Q6_Vsf_vfneg_Vsf(v);
#else
// neg by setting the fp32 sign bit
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
return Q6_V_vxor_VV(v, mask);
-#endif // __HTP_ARCH__ > 75
+#endif // __HVX_ARCH__ > 75
}
// ====================================================
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
}
+static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
+ // tanh(x) = 2 * sigmoid(2x) - 1
+ HVX_Vector two = hvx_vec_splat_fp32(2.0f);
+ HVX_Vector one = hvx_vec_splat_fp32(1.0f);
+ HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
+
+    static const float kMinExp = -87.f; // sigmoid saturates to 0 below this
+    static const float kMaxExp = 87.f;  // sigmoid saturates to 1 above this
+ HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
+ HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
+
+ HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
+
+ HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
+ res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
+ return Q6_Vsf_equals_Vqf32(res);
+}
+
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
int step_of_1 = num_elems >> 5;
int remaining = num_elems - step_of_1 * VLEN_FP32;
}
}
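+// Scale n fp32 elements by `scale`; _aa assumes VLEN-aligned src/dst, _uu handles unaligned pointers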
+static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ int nvec = n / VLEN_FP32;
+ int nloe = n % VLEN_FP32;
+
+ HVX_Vector vs = hvx_vec_splat_fp32(scale);
+
+ HVX_Vector * vsrc = (HVX_Vector *) src;
+ HVX_Vector * vdst = (HVX_Vector *) dst;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; ++i) {
+ HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+ vdst[i] = Q6_Vsf_equals_Vqf32(v);
+ }
+
+ if (nloe) {
+ HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+ hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+ }
+}
+
+static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ int nvec = n / VLEN_FP32;
+ int nloe = n % VLEN_FP32;
+
+ HVX_Vector vs = hvx_vec_splat_fp32(scale);
+
+ HVX_UVector * vsrc = (HVX_UVector *) src;
+ HVX_UVector * vdst = (HVX_UVector *) dst;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; ++i) {
+ HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+ vdst[i] = Q6_Vsf_equals_Vqf32(v);
+ }
+
+ if (nloe) {
+ HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
+ hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+ }
+}
+
+static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
+ if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
+ hvx_scale_f32_aa(dst, src, n, scale);
+ } else {
+ hvx_scale_f32_uu(dst, src, n, scale);
+ }
+}
+
+static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ int nvec = n / VLEN_FP32;
+ int nloe = n % VLEN_FP32;
+
+ HVX_Vector vs = hvx_vec_splat_fp32(scale);
+ HVX_Vector vo = hvx_vec_splat_fp32(offset);
+
+ HVX_Vector * vsrc = (HVX_Vector *) src;
+ HVX_Vector * vdst = (HVX_Vector *) dst;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; ++i) {
+ HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+ vdst[i] = Q6_Vsf_equals_Vqf32(v);
+ }
+
+ if (nloe) {
+ HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+ hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+ }
+}
+
+static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ int nvec = n / VLEN_FP32;
+ int nloe = n % VLEN_FP32;
+
+ HVX_Vector vs = hvx_vec_splat_fp32(scale);
+ HVX_Vector vo = hvx_vec_splat_fp32(offset);
+
+ HVX_UVector * vsrc = (HVX_UVector *) src;
+ HVX_UVector * vdst = (HVX_UVector *) dst;
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; ++i) {
+ HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+ vdst[i] = Q6_Vsf_equals_Vqf32(v);
+ }
+
+ if (nloe) {
+ HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
+ hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
+ }
+}
+
+static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
+ if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
+ hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
+ } else {
+ hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
+ }
+}
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
void hvx_mul_f32(const uint8_t * restrict src0,
uint8_t * restrict dst,
const int num_elems);
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
-void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
+static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+ // We had written to the output buffer, we'd also need to flush it
+ rsp_bufs[0].fd = bufs[2].fd;
+ rsp_bufs[0].ptr = bufs[2].ptr;
+ rsp_bufs[0].offset = bufs[2].offset;
+ rsp_bufs[0].size = bufs[2].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_get_rows(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
static void proc_matmul_id_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
uint32_t n_bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
- int write_idx = (n_bufs == 4) ? 3 : 2;
+ int write_idx = n_bufs - 1;
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[write_idx].fd;
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
+static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
+ struct dspqueue_buffer rsp_bufs[1];
+
+ // We had written to the output buffer, we'd also need to flush it
+ rsp_bufs[0].fd = bufs[2].fd;
+ rsp_bufs[0].ptr = bufs[2].ptr;
+ rsp_bufs[0].offset = bufs[2].offset;
+ rsp_bufs[0].size = bufs[2].size;
+ rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+ // Setup Op context
+ struct htp_ops_context octx = { 0 };
+ octx.ctx = ctx;
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.dst.data = (uint32_t) bufs[2].ptr;
+ octx.n_threads = ctx->n_threads;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_set_rows(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+ send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
+}
+
+static void proc_flash_attn_ext_req(struct htp_context * ctx,
+ struct htp_general_req * req,
+ struct dspqueue_buffer * bufs,
+ uint32_t n_bufs) {
+ // Setup Op context
+ struct htp_ops_context octx;
+ memset(&octx, 0, sizeof(octx));
+
+ octx.ctx = ctx;
+ octx.n_threads = ctx->n_threads;
+
+ octx.src0 = req->src0;
+ octx.src1 = req->src1;
+ octx.src2 = req->src2;
+ octx.src3 = req->src3;
+ octx.src4 = req->src4;
+ octx.dst = req->dst;
+ octx.flags = req->flags;
+ octx.op = req->op;
+
+ memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
+
+ // Update data pointers
+ octx.src0.data = (uint32_t) bufs[0].ptr;
+ octx.src1.data = (uint32_t) bufs[1].ptr;
+ octx.src2.data = (uint32_t) bufs[2].ptr;
+
+ int last_buf = 3;
+
+ if (octx.src3.ne[0]) {
+ octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
+ }
+
+ if (octx.src4.ne[0]) {
+ octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
+ }
+
+ octx.dst.data = (uint32_t) bufs[last_buf].ptr;
+
+ struct profile_data prof;
+ profile_start(&prof);
+
+ uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
+ if (vtcm_acquire(ctx) == AEE_SUCCESS) {
+ rsp_status = op_flash_attn_ext(&octx);
+ vtcm_release(ctx);
+ }
+
+ profile_stop(&prof);
+
+ struct dspqueue_buffer rsp_buf = bufs[last_buf];
+ rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
+ DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
+
+    send_htp_rsp(ctx, req->op, rsp_status, &rsp_buf, 1, &prof);
+}
+
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context;
break;
case HTP_OP_RMS_NORM:
+ case HTP_OP_SCALE:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
proc_rope_req(ctx, &req, bufs, n_bufs);
break;
+ case HTP_OP_FLASH_ATTN_EXT:
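+            // 4 mandatory buffers (Q, K, V, dst) plus optional mask and sinks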
+ if (!(n_bufs >= 4 && n_bufs <= 6)) {
+ FARF(ERROR, "Bad flash-attn-ext-req buffer list");
+ continue;
+ }
+ proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
+ break;
+
+ case HTP_OP_SET_ROWS:
+ if (n_bufs != 3) {
+ FARF(ERROR, "Bad set-rows-req buffer list");
+ continue;
+ }
+ proc_set_rows_req(ctx, &req, bufs);
+ break;
+
+ case HTP_OP_GET_ROWS:
+ if (n_bufs != 3) {
+ FARF(ERROR, "Bad get-rows-req buffer list");
+ continue;
+ }
+ proc_get_rows_req(ctx, &req, bufs);
+ break;
+
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;
#include "hvx-utils.h"
#include "ops-utils.h"
+#define MM_SPAD_SRC0_NROWS 16
+#define MM_SPAD_SRC1_NROWS 16
+#define MM_SPAD_DST_NROWS 2
+
struct htp_matmul_type {
const char * type;
void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- void (*vec_dot_rx2)(const int n,
- float * restrict s,
- const void * restrict vx,
- uint32_t vx_row_size,
- const void * restrict vy);
+ void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
};
typedef struct {
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
}
-#if 1
-static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
- if (0) {
- float rsum = 0;
- const __fp16 * restrict vx = (const __fp16 * restrict) x;
- const float * restrict vy = (const float * restrict) y;
+static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const HVX_Vector * restrict x = (const HVX_Vector *) vx;
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
- for (uint32_t i = 0; i < n; i++) {
- rsum += (float)vx[i] * vy[i];
- }
- *s = rsum;
- return;
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
- const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y;
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
- uint32_t nv0 = n / 64; // num full fp16 hvx vectors
- uint32_t nv1 = n % 64; // leftover elements
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
+}
+
+static void vec_dot_f16_f16_aa_rx2(const int n,
+ float * restrict s,
+ const void * restrict vx,
+ uint32_t vx_row_size,
+ const void * restrict vy) {
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
+
+ uint32_t nvec = n / VLEN_FP16;
+ uint32_t nloe = n % VLEN_FP16;
+
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
- // for some reason we need volatile here so that the compiler doesn't try anything funky
- volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
- float r_sum_scalar = 0.0f;
uint32_t i = 0;
- for (i = 0; i < nv0; i++) {
- HVX_VectorPair yp = vy[i];
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = y[i];
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
- HVX_Vector x = vx[i];
- HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+ }
- //NOTE: need volatile here to prevent compiler optimization
- // Seem compiler cannot guarantee read-after-write??
- volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
- volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
- HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
}
- if (nv1) {
- // HVX_VectorPair yp = vy[i];
+ rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1));
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
- // HVX_Vector x = vx[i];
- // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+}
- // if (nv1 >= 32) {
- // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
- // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
- // nv1 -= 32;
- // }
+static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const HVX_UVector * restrict x = (const HVX_UVector *) vx;
+ const HVX_UVector * restrict y = (const HVX_UVector *) vy;
- // rsum = hvx_vec_qf32_reduce_sum(rsum);
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
- // if (nv1) {
- // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
- // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
- // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
- // }
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
- //process the remainder using scalar loop
- rsum = hvx_vec_qf32_reduce_sum(rsum);
- const __fp16 * restrict sx = (const __fp16 * restrict) x;
- const float * restrict sy = (const float * restrict) y;
+ uint32_t i = 0;
- for (uint32_t i = nv0 * 64; i < n; i++) {
- r_sum_scalar += (float) sx[i] * sy[i];
- }
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
- // hvx_vec_dump_fp16("X", x);
- // hvx_vec_dump_fp16("Y", y);
- // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum));
- // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum));
- } else {
- rsum = hvx_vec_qf32_reduce_sum(rsum);
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
+}
-# ifdef HTP_DEBUG
- {
- float rsum = 0;
- const __fp16 * restrict vx = (const __fp16 * restrict) x;
- const float * restrict vy = (const float * restrict) y;
+static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+ const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
+ const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
- for (uint32_t i = 0; i < n; i++) {
- rsum += vx[i] * vy[i];
- }
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
- float diff = fabs(*s - rsum);
- if (diff > 0.001) {
- FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s);
- // htp_dump_f16("x", vx, n);
- // htp_dump_f32("y", vy, n);
- }
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
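+ // (subtracting zero is the usual sf -> qf32 conversion idiom; the vdeal restores linear fp16 element order after the paired qf32 -> hf conversion)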
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
-# endif
-}
-#else
-static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
- const uint32_t fk = 64;
- const uint32_t nb = n / fk;
- assert(n % fk == 0);
- assert(nb % 4 == 0);
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
- const uint32_t x_blk_size = 2 * fk; // fp16
- const uint32_t y_blk_size = 4 * fk; // fp32
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
- // Row sum (qf32)
- HVX_Vector rsum0 = Q6_V_vsplat_R(0);
- HVX_Vector rsum1 = Q6_V_vsplat_R(0);
- HVX_Vector rsum2 = Q6_V_vsplat_R(0);
- HVX_Vector rsum3 = Q6_V_vsplat_R(0);
-
- for (uint32_t i = 0; i < nb; i += 4) {
- HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size));
- HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size));
-
- HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]);
- HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]);
- HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]);
- HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]);
-
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0)));
- rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1)));
- rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2)));
- rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3)));
+ // Zero-out unused elements
+ // Note that we need to clear both x and y because they may contain NANs
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x_hf = Q6_V_vand_QV(bmask, x_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- // Reduce and convert into fp32
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1);
- rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3);
- HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2));
- hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum));
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
}
-#endif
-#define htp_matmul_preamble \
+#define htp_matmul_tensors_preamble \
+ struct htp_tensor * restrict src0 = &octx->src0; \
+ struct htp_tensor * restrict src1 = &octx->src1; \
+ struct htp_tensor * restrict src2 = &octx->src2; \
+ struct htp_tensor * restrict dst = &octx->dst; \
+ struct htp_spad * restrict src0_spad = &octx->src0_spad; \
+ struct htp_spad * restrict src1_spad = &octx->src1_spad; \
+ struct htp_spad * restrict dst_spad = &octx->dst_spad; \
+ \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne12 = src1->ne[2]; \
const uint32_t ne13 = src1->ne[3]; \
\
+ const uint32_t ne20 = src2->ne[0]; \
+ const uint32_t ne21 = src2->ne[1]; \
+ const uint32_t ne22 = src2->ne[2]; \
+ const uint32_t ne23 = src2->ne[3]; \
+ \
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
-// q8x4 src1 tensor is already in VTCM spad
-static void matmul(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+#define htp_matmul_preamble \
+ htp_matmul_tensors_preamble; \
+ dma_queue * dma_queue = octx->ctx->dma[ith]; \
+ uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+// *** matmul with support for 4d tensors and full broadcasting
+
+static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+ htp_matmul_preamble;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ assert(ne12 % ne02 == 0);
+ assert(ne13 % ne03 == 0);
+
+ // Size of the first dimension of the result (ne0 == ne01), i.e. the number of src0 rows to iterate over
+ const uint32_t nr0 = ne0;
+
+ // This is the size of the rest of the dimensions of the result
+ const uint32_t nr1 = ne1 * ne2 * ne3;
+
+ // distribute the thread work across the inner or outer loop based on which one is larger
+ uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+ uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+ // The number of elements in each chunk
+ const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+ const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
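+ // e.g. nr0=4096, nr1=8, nth=4 -> nchunk0=4, dr0=1024: thread 2 covers src0 rows [2048,3072) and all dst rows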
+ uint32_t current_chunk = ith;
+
+ const uint32_t ith0 = current_chunk % nchunk0;
+ const uint32_t ith1 = current_chunk / nchunk0;
+
+ const uint32_t ir0_start = dr0 * ith0;
+ const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
+
+ const uint32_t ir1_start = dr1 * ith1;
+ const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+ // no work for this thread
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+ return;
+ }
+
+ // block-tiling attempt
+ const uint32_t blck_0 = 64;
+ const uint32_t blck_1 = 64;
+
+ for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+ for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+ for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
+ const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
+ const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
+ const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
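+ // ir1 linearizes the dst row index: ir1 = i11 + i12*ne1 + i13*ne1*ne12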
+
+ // broadcast src0 into src1
+ const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
+ const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);
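+ // mm_div_r2/r3 hold the broadcast ratios ne12/ne02 and ne13/ne03 (src0 planes are reused across src1 batches)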
+
+ const uint32_t i1 = i11;
+ const uint32_t i2 = i12;
+ const uint32_t i3 = i13;
+
+ const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+ const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
+ float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+ const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
+ for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
+ const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+ mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
+ }
+ }
+ }
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
+ src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// src1 tensor is already in VTCM spad
+static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const size_t dst_row_size = nb1;
const size_t src0_row_size = nb01;
- const size_t src1_row_size = q8x4x2_row_size(ne10);
+ const size_t src1_row_size = nb11;
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+ const size_t src0_stride = src0_spad->stride;
+ const size_t src1_stride = src1_spad->stride;
// Per-thread VTCM scratchpads for all tensors
// Note that the entire src1 tensor is already in VTCM
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
// Process src0 rows
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
- const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+ mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
}
// Prefetch next (n + spad_nrows) row
- const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
}
if (src0_end_row != src0_end_row_x2) {
uint32_t ir0 = src0_end_row_x2;
const int is0 = (ir0 - src0_start_row);
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 1);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
- const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
}
}
// q8x4x2 src1 tensor is already in VTCM spad
-static void matvec(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
const uint32_t src0_nrows = ne01;
const size_t dst_row_size = nb1;
const size_t src0_row_size = nb01;
- const size_t src1_row_size = q8x4x2_row_size(ne10);
+ const size_t src1_row_size = nb11;
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+ const size_t src0_stride = src0_spad->stride;
+ const size_t src1_stride = src1_spad->stride;
// Per-thread VTCM scratchpads for all tensors
// Note that the entire src1 tensor is already in VTCM
#pragma unroll(2)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint32_t is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
- mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col);
+ mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);
// Prefetch next (n + spad_nrows) row
- const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
}
if (src0_end_row != src0_end_row_x2) {
const uint32_t ir0 = src0_end_row_x2;
const uint32_t is0 = (ir0 - src0_start_row);
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 1);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
}
uint32_t i2;
};
-// q8x4 src1 tensor is already in VTCM spad
-static void matmul_id(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict ids,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict src2_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+// src1 tensor is already in VTCM spad
+static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
+ struct htp_tensor * restrict ids = &octx->src2;
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
}
// Prefetch next (n + spad_nrows) row
- const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
src0_row_size_padded, src0_row_size, 2);
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-// q8x4 src1 tensor is already in VTCM spad
-static void matvec_id(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict src2,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict src2_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+// src1 tensor is already in VTCM spad
+static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
+ struct htp_tensor * restrict ids = &octx->src2;
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
// Prefetch next (n + spad_nrows) row
- const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
src0_row_size_padded, src0_row_size, 2);
dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-// *** matmul in fp16
-
-static void matmul_f16_f32(struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
- htp_matmul_preamble;
-
- uint64_t t1, t2;
- t1 = HAP_perf_get_qtimer_count();
-
- assert(ne12 % ne02 == 0);
- assert(ne13 % ne03 == 0);
-
- // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
- const uint32_t nr0 = ne0;
-
- // This is the size of the rest of the dimensions of the result
- const uint32_t nr1 = ne1 * ne2 * ne3;
-
- // distribute the thread work across the inner or outer loop based on which one is larger
- uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
- uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
- // The number of elements in each chunk
- const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
- const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
- uint32_t current_chunk = ith;
-
- const uint32_t ith0 = current_chunk % nchunk0;
- const uint32_t ith1 = current_chunk / nchunk0;
-
- const uint32_t ir0_start = dr0 * ith0;
- const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
-
- const uint32_t ir1_start = dr1 * ith1;
- const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
-
- // broadcast factors
- const uint32_t r2 = ne12 / ne02;
- const uint32_t r3 = ne13 / ne03;
-
- // no work for this thread
- if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
- return;
- }
-
- // block-tiling attempt
- const uint32_t blck_0 = 64;
- const uint32_t blck_1 = 64;
-
- __attribute__((aligned(128))) float tmp[64];
-
- for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
- for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
- for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
- const uint32_t i13 = (ir1 / (ne12 * ne1));
- const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
- const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
-
- // broadcast src0 into src1
- const uint32_t i03 = i13 / r3;
- const uint32_t i02 = i12 / r2;
-
- const uint32_t i1 = i11;
- const uint32_t i2 = i12;
- const uint32_t i3 = i13;
-
- const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
- const uint8_t * restrict src1_col =
- (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
- float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
-
- const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
- for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
- // Use nb01 stride for non-contiguous src0 support
- const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
- vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col);
- }
-
- hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
- }
- }
- }
-
- t2 = HAP_perf_get_qtimer_count();
-
- FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
- src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
- (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
-}
-
// *** dynamic quant
static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
for (uint32_t i = 0; i < nb; i++) {
#if FP32_QUANTIZE_GROUP_SIZE == 32
- quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
- t_d + (i * 2 + 0) * dblk_size / 2);
- quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
- t_d + (i * 2 + 1) * dblk_size / 2);
+ quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 64
- quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
- t_d + (i * 2 + 0) * dblk_size / 2);
- quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
- t_d + (i * 2 + 1) * dblk_size / 2);
+ quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 128
- quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
- t_d + (i * 2 + 0) * dblk_size / 2);
- quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
- t_d + (i * 2 + 1) * dblk_size / 2);
+ quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#else
#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
#endif
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
+static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
+ uint32_t nrows_per_thread, uint32_t dst_stride) {
+
+ uint64_t t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t ne0 = src->ne[0];
+ const uint32_t ne1 = src->ne[1];
+ const uint32_t ne2 = src->ne[2];
+ const uint32_t ne3 = src->ne[3];
+
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+ const size_t src_row_size = ne0 * sizeof(float);
+ const size_t src_stride = src->nb[1];
+
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
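+ // each thread converts its slice of rows from fp32 to fp16, writing rows dst_stride bytes apart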
+
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
+ htp_l2fetch(src_data, 2, src_row_size, src_stride);
+ hvx_copy_fp16_fp32_au(dst_data, src_data, ne0);
+
+ dst_data += dst_stride;
+ src_data += src_stride;
+ }
+
+ uint64_t t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// TODO: this is just a plain copy; it should be done via DMA during the Op setup
+static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
+ uint32_t nrows_per_thread, uint32_t dst_stride) {
+
+ uint64_t t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t ne0 = src->ne[0];
+ const uint32_t ne1 = src->ne[1];
+ const uint32_t ne2 = src->ne[2];
+ const uint32_t ne3 = src->ne[3];
+
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+ const size_t src_row_size = ne0 * sizeof(__fp16);
+ const size_t src_stride = src->nb[1];
+
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
+
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
+ htp_l2fetch(src_data, 2, src_row_size, src_stride);
+ hvx_copy_fp16_au(dst_data, src_data, ne0);
+
+ dst_data += dst_stride;
+ src_data += src_stride;
+ }
+
+ uint64_t t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
}
-// ** matmul callbacks for worker_pool
+static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+ quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
+}
+
+static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+ quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
+}
+
+// ** matmul/matvec callbacks for worker_pool
-static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_2d(&mt, octx, n, i);
}
-static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_2d(&mt, octx, n, i);
}
-static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_2d(&mt, octx, n, i);
}
-static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_2d(&mt, octx, n, i);
}
-static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_2d(&mt, octx, n, i);
}
-static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_2d(&mt, octx, n, i);
}
-static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
- matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f16";
+ mt.vec_dot = vec_dot_f16_f16_aa;
+ mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
+
+ matvec_2d(&mt, octx, n, i);
+}
+
+static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f16";
+ mt.vec_dot = vec_dot_f16_f16_aa;
+ mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
+
+ matmul_2d(&mt, octx, n, i);
+}
+
+static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f32";
+ mt.vec_dot = vec_dot_f16_f32_uu;
+
+ matmul_4d(&mt, octx, n, i);
+}
+
+static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f16";
+ mt.vec_dot = vec_dot_f16_f16_uu;
+
+ matmul_4d(&mt, octx, n, i);
}
// ** matmul-id callbacks for worker_pool
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_id(&mt, octx, n, i);
}
static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_id(&mt, octx, n, i);
}
static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_id(&mt, octx, n, i);
}
static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_id(&mt, octx, n, i);
}
static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_id(&mt, octx, n, i);
}
static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_id(&mt, octx, n, i);
}
// ** main matmul entry point
-int op_matmul(struct htp_ops_context * octx) {
- const struct htp_tensor * src0 = &octx->src0;
- const struct htp_tensor * src1 = &octx->src1;
- struct htp_tensor * dst = &octx->dst;
+static inline bool htp_is_permuted(const struct htp_tensor * t) {
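+ // a non-permuted (row-major) layout has non-decreasing strides: nb[0] <= nb[1] <= nb[2] <= nb[3]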
+ return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
+}
- htp_matmul_preamble;
+int op_matmul(struct htp_ops_context * octx) {
+ htp_matmul_tensors_preamble;
const char * op_type;
op_type = "q4x4x2-fp32";
quant_job_func = htp_quantize_fp32_q8x4x2;
if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_q4x4x2_q8x4x2;
+ matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
} else {
- matmul_job_func = htp_matvec_q4x4x2_q8x4x2;
+ matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
}
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
op_type = "q8x4x2-fp32";
quant_job_func = htp_quantize_fp32_q8x4x2;
if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_q8x4x2_q8x4x2;
+ matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
} else {
- matmul_job_func = htp_matvec_q8x4x2_q8x4x2;
+ matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
}
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
op_type = "mxfp4x4x2-f32";
quant_job_func = htp_quantize_fp32_q8x4x2;
if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2;
+ matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
} else {
- matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2;
+ matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
}
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
break;
case HTP_TYPE_F16:
- op_type = "f16-f32";
- quant_job_func = NULL; // htp_quantize_f32_f16;
- matmul_job_func = htp_matmul_f16_f32;
-
- // For all tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256);
- octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256);
-
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
-
- need_quant = false;
+ {
+ // Try optimized f16-f16 path first (src1 in VTCM)
+ const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128);
+ const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256);
+ const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+ const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
+
+ const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
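+ // the fast path requires all three scratchpads (src0 rows, the whole src1, dst rows) to fit in VTCM at once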
+
+ // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting).
+ // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul.
+ const bool is_batched = (ne02 > 1) || (ne03 > 1);
+ const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
+
+ if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+ // Optimized path
+ op_type = "f16-f16";
+ quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
+ if (src1_nrows > 1) {
+ matmul_job_func = htp_matmul_2d_f16_f16;
+ } else {
+ matmul_job_func = htp_matvec_2d_f16_f16;
+ }
+
+ src1_row_size = f16_src1_row_size; // row size post quantization
+
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+ } else {
+ // Fall back to the 4D path (operands read from DDR) if src1 does not fit in VTCM, or if src0/src1 are batched or permuted
+ quant_job_func = NULL;
+ if (src1->type == HTP_TYPE_F32) {
+ op_type = "f16-f32";
+ matmul_job_func = htp_matmul_4d_f16_f32;
+ } else {
+ op_type = "f16-f16";
+ matmul_job_func = htp_matmul_4d_f16_f16;
+ }
+
+ src1_row_size = nb11; // original row size in DDR
+
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
+ octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
+
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+
+ // Init fastdiv for matmul_4d (supports broadcasting)
+ octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
+ octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
+ octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
+ octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
+
+ need_quant = false;
+ }
+ }
break;
default:
octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
+ octx->src0_spad.stride = src0_row_size_padded;
+ octx->src1_spad.stride = src1_row_size;
+
if (need_quant) {
// Run quant jobs
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
// ** main matmul-id entry point
int op_matmul_id(struct htp_ops_context * octx) {
- const struct htp_tensor * src0 = &octx->src0;
- const struct htp_tensor * src1 = &octx->src1;
- const struct htp_tensor * ids = &octx->src2;
- struct htp_tensor * dst = &octx->dst;
+ htp_matmul_tensors_preamble;
- htp_matmul_preamble;
+ struct htp_tensor * restrict ids = &octx->src2;
const char * op_type;
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
--- /dev/null
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+# define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+#define set_rows_preamble \
+ const uint32_t ne00 = octx->src0.ne[0]; \
+ const uint32_t ne01 = octx->src0.ne[1]; \
+ const uint32_t ne02 = octx->src0.ne[2]; \
+ const uint32_t ne03 = octx->src0.ne[3]; \
+ \
+ const uint32_t ne10 = octx->src1.ne[0]; \
+ const uint32_t ne11 = octx->src1.ne[1]; \
+ const uint32_t ne12 = octx->src1.ne[2]; \
+ \
+ const uint32_t nb01 = octx->src0.nb[1]; \
+ const uint32_t nb02 = octx->src0.nb[2]; \
+ const uint32_t nb03 = octx->src0.nb[3]; \
+ \
+ const uint32_t nb10 = octx->src1.nb[0]; \
+ const uint32_t nb11 = octx->src1.nb[1]; \
+ const uint32_t nb12 = octx->src1.nb[2]; \
+ \
+ const uint32_t nb1 = octx->dst.nb[1]; \
+ const uint32_t nb2 = octx->dst.nb[2]; \
+ const uint32_t nb3 = octx->dst.nb[3]; \
+ \
+ const uint32_t ne1 = octx->dst.ne[1]; \
+ \
+ const uint32_t nr = ne01;
+
+static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+ set_rows_preamble;
+
+ // parallelize by rows of src0
+ const uint32_t dr = octx->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+ const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+ for (uint32_t i03 = 0; i03 < ne03; ++i03) {
+ for (uint32_t i02 = 0; i02 < ne02; ++i02) {
+ for (uint32_t i = ir0; i < ir1; ++i) {
+ const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
+ const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+ const uint32_t i10 = i;
+
+ const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
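+ // src1 supplies the destination row index (32- or 64-bit)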
+ uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+ if (i1 >= ne1) {
+ // ignore invalid indices
+ continue;
+ }
+
+ const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
+ const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
+
+ // copy row
+ hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
+ }
+ }
+ }
+
+ return HTP_STATUS_OK;
+}
+
+static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
+ set_rows_preamble;
+
+ // parallelize by rows of src0
+ const uint32_t dr = octx->src0_nrows_per_thread;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
+
+ const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
+
+ for (uint32_t i03 = 0; i03 < ne03; ++i03) {
+ for (uint32_t i02 = 0; i02 < ne02; ++i02) {
+ for (uint32_t i = ir0; i < ir1; ++i) {
+ const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
+ const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
+ const uint32_t i10 = i;
+
+ const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
+
+ uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
+ if (i1 >= ne1) {
+ // ignore invalid indices
+ continue;
+ }
+
+ const uint8_t * src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
+ uint8_t * dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
+
+ hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
+ }
+ }
+ }
+
+ return HTP_STATUS_OK;
+}
+
+static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
+ set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
+}
+
+static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
+ set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
+}
+
+int op_set_rows(struct htp_ops_context * octx) {
+ set_rows_preamble;
+
+ if (octx->src0.type != HTP_TYPE_F32) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
+ return HTP_STATUS_OK;
+ }
+
+ octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
+ octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
+
+ const uint32_t n_jobs = MIN(nr, octx->n_threads);
+ octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
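+ // e.g. nr=100 rows on 4 threads -> 25 rows per job; the last job clamps at nr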
+
+ switch (octx->dst.type) {
+ case HTP_TYPE_F32:
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
+ break;
+ case HTP_TYPE_F16:
+ worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
+ break;
+ default:
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ return HTP_STATUS_OK;
+}
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
(const uint8_t *) mp_f32, slope);
} else {
- hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
+ hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
if (mp_f32) {
if (softmax_ctx->use_f16) {
for (int i = 0; i < ne00; ++i) {
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1;
- hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
+ hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
}
}
}
}
}
+static void scale_htp_f32(const float * restrict src,
+ float * restrict dst,
+ uint8_t * restrict spad,
+ const uint32_t num_rows,
+ const uint32_t row_elems,
+ const size_t row_size,
+ int32_t * op_params,
+ int opt_path) {
+ float scale = 0.f;
+ float bias = 0.f;
+ memcpy(&scale, &op_params[0], sizeof(float));
+ memcpy(&bias, &op_params[1], sizeof(float));
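+ // (op_params packs the raw float bits of scale and bias, hence the memcpys above)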
+
+ for (uint32_t ir = 0; ir < num_rows; ir++) {
+ const float * restrict src_local = src + (ir * row_elems);
+ float * restrict dst_local = dst + (ir * row_elems);
+
+ if (ir + 1 < num_rows) {
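+ // prefetch the next row into L2 while the current one is scaled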
+ htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
+ }
+
+ hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
+ }
+}
+
static void rms_norm_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const float mean = sum / row_elems;
const float scale = 1.0f / sqrtf(mean + epsilon);
- hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
+ hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
}
}
}
case HTP_OP_RMS_NORM:
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
+ case HTP_OP_SCALE:
+ scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
+ break;
default:
break;
unary_op_func = unary_job_dispatcher_f32;
op_type = "rmsnorm-f32";
break;
+ case HTP_OP_SCALE:
+ unary_op_func = unary_job_dispatcher_f32;
+ op_type = "scale-f32";
+ break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
device="HTP0"
[ "$D" != "" ] && device="$D"
-verbose=""
-[ "$V" != "" ] && verbose="$V"
+verbose=
+[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v"
+
+experimental=
+[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
+
+profile=
+[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v"
opmask=
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
cd $basedir; \
LD_LIBRARY_PATH=$basedir/$branch/lib \
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
- $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
+ $ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
- --batch-size 128 -ngl 99 $@ \
+ --batch-size 128 -ngl 99 $cli_opts $@ \
"