--- /dev/null
+#pragma clang diagnostic ignored "-Wunused-variable"
+#pragma clang diagnostic ignored "-Wunused-function"
+#pragma clang diagnostic ignored "-Wunused-but-set-variable"
+
+#ifdef HTP_DEBUG
+# define FARF_HIGH 1
+#endif
+#include <HAP_farf.h>
+#include <HAP_mem.h>
+#include <HAP_perf.h>
+#include <hexagon_protos.h>
+#include <hexagon_types.h>
+#include <math.h>
+#include <string.h>
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+#include "htp-ctx.h"
+#include "htp-dma.h"
+#include "htp-msg.h"
+#include "htp-ops.h"
+#include "hvx-utils.h"
+#include "ops-utils.h"
+
+// Dot product of an FP32 vector (y) with an FP16 vector (x), scaled by s; the float result is stored to r
+static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
+ const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
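+        // Subtracting zero converts the IEEE fp32 (sf) values to qf32; the qf32 pair is then narrowed to fp16, and vdeal restores linear lane order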
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+        // Zero out the unused leftover lanes
+        // Note that both x and y must be cleared because either one may contain NaNs
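+        // Q6_Q_vsetq_R takes a byte count, hence nloe * 2 bytes for nloe fp16 lanes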
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x_hf = Q6_V_vand_QV(bmask, x_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+
+ hvx_vec_store_u(r, 4, rsum);
+}
+
+// Dot product of two FP16 vectors, scaled by s; the float result is stored to r
+static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
+ const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
+ const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = vy[i];
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ HVX_Vector y_hf = vy[i];
+
+ // Load x (fp16) and zero-out unused elements
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(r, 4, rsum);
+}
+
+// MAD: y (F32) += x (F16) * s (float)
+static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
+ const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
+ HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
+
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector S = hvx_vec_splat_fp16(s);
+
+ uint32_t i = 0;
+ #pragma unroll(4)
+ for (i = 0; i < nvec; ++i) {
+ // Multiply x * s -> pair of F32 vectors
+ HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+ ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
+ ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
+ }
+
+ if (nloe) {
+ HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
+
+ HVX_Vector xs = Q6_V_lo_W(xs_p);
+ i = 2 * i; // index for ptr_y
+
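+        // The leftover spans at most two fp32 result vectors: lo holds products for the first 32 elements, hi the remainder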
+ if (nloe >= 32) {
+ ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+ nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
+ }
+
+ if (nloe) {
+ HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
+ hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
+ }
+ }
+}
+
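+// Number of K/V rows (and mask elements) processed per DMA block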
+#define FLASH_ATTN_BLOCK_SIZE 128
+
+static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
+ const struct htp_tensor * q = &octx->src0;
+ const struct htp_tensor * k = &octx->src1;
+ const struct htp_tensor * v = &octx->src2;
+ const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
+ const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
+ struct htp_tensor * dst = &octx->dst;
+
+ const uint32_t neq0 = q->ne[0];
+ const uint32_t neq1 = q->ne[1];
+ const uint32_t neq2 = q->ne[2];
+ const uint32_t neq3 = q->ne[3];
+
+ const uint32_t nek0 = k->ne[0];
+ const uint32_t nek1 = k->ne[1];
+ const uint32_t nek2 = k->ne[2];
+ const uint32_t nek3 = k->ne[3];
+
+ const uint32_t nev0 = v->ne[0];
+ const uint32_t nev1 = v->ne[1];
+ const uint32_t nev2 = v->ne[2];
+ const uint32_t nev3 = v->ne[3];
+
+ const uint32_t nbq1 = q->nb[1];
+ const uint32_t nbq2 = q->nb[2];
+ const uint32_t nbq3 = q->nb[3];
+
+ const uint32_t nbk1 = k->nb[1];
+ const uint32_t nbk2 = k->nb[2];
+ const uint32_t nbk3 = k->nb[3];
+
+ const uint32_t nbv1 = v->nb[1];
+ const uint32_t nbv2 = v->nb[2];
+ const uint32_t nbv3 = v->nb[3];
+
+ const uint32_t ne1 = dst->ne[1];
+ const uint32_t ne2 = dst->ne[2];
+ const uint32_t ne3 = dst->ne[3];
+
+ const uint32_t nb1 = dst->nb[1];
+ const uint32_t nb2 = dst->nb[2];
+ const uint32_t nb3 = dst->nb[3];
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
+
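+    // Fold the softcap into the scale so that scores become logit_softcap * tanh((scale / logit_softcap) * (q . k))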
+ if (logit_softcap != 0) {
+ scale /= logit_softcap;
+ }
+
+ // total rows in q
+ const uint32_t nr = neq1*neq2*neq3;
+
+ const uint32_t dr = (nr + nth - 1) / nth;
+ const uint32_t ir0 = dr * ith;
+ const uint32_t ir1 = MIN(ir0 + dr, nr);
+
+ if (ir0 >= ir1) return;
+
+ dma_queue * dma = octx->ctx->dma[ith];
+
+ const uint32_t DK = nek0;
+ const uint32_t DV = nev0;
+
+ const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
+ const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
+
+ const size_t size_k_row = DK * sizeof(__fp16);
+ const size_t size_v_row = DV * sizeof(__fp16);
+ const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
+
+ const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
+ const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
+
+ const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+ // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
+ uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
+ uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
+ uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
+ uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
+ uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
+
+ const uint32_t n_head = neq2;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ for (uint32_t ir = ir0; ir < ir1; ++ir) {
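+        // Decompose the flat row index into (iq1, iq2, iq3) using the fastdiv divisors precomputed in op_flash_attn_ext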
+ const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
+ const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
+ const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
+
+ const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
+ const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
+
+ const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
+ const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
+
+ // Fetch Q row
+ const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
+ dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
+
+ const uint32_t h = iq2; // head index
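+        // ALiBi slope for this head: heads below n_head_log2 use powers of m0, the rest odd powers of m1; slope stays 1.0 when max_bias == 0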
+ const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
+
+ float S = 0.0f; // sum
+ float M = -INFINITY; // maximum KQ value
+
+ // Clear accumulator
+ float * VKQ32 = (float *) spad_a;
+ memset(VKQ32, 0, DV * sizeof(float));
+
+ const __fp16 * mp_base = NULL;
+ if (mask) {
+ const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
+ const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
+ mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
+ }
+
+ const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
+
+ // Prefetch first two blocks
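+        // K/V/mask spads are double-buffered: (ib % 2) selects the ping/pong half so the DMA for block ib+2 can overlap with compute on block ib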
+ for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
+ const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+ // K
+ const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+ uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
+ dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
+
+ // V
+ const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+ uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
+ dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
+
+ // Mask
+ if (mask) {
+ const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
+ uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
+ // Mask is 1D contiguous for this row
+ dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
+ }
+ }
+
+ const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+
+ for (uint32_t ib = 0; ib < n_blocks; ++ib) {
+ const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
+
+ // Wait for DMA
+ uint8_t * k_base = dma_queue_pop(dma).dst; // K
+ uint8_t * v_base = dma_queue_pop(dma).dst; // V
+ __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
+
+ // Inner loop processing the block from VTCM
+ uint32_t ic = 0;
+
+ // Process in blocks of 32 (VLEN_FP32)
+ for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
+ // 1. Compute scores
+ float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+ for (int j = 0; j < VLEN_FP32; ++j) {
+ const uint32_t cur_ic = ic + j;
+ const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
+ if (q->type == HTP_TYPE_F32) {
+ hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+ } else {
+ hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+ }
+ }
+
+ HVX_Vector scores = *(HVX_Vector *) scores_arr;
+
+ // 2. Softcap
+ if (logit_softcap != 0.0f) {
+ scores = hvx_vec_tanh_fp32(scores);
+ scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
+ scores = Q6_Vsf_equals_Vqf32(scores);
+ }
+
+ // 3. Mask
+ if (mask) {
+ const __fp16 * mp = m_base + ic;
+ HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
+
+ HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
+ HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
+
+ HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
+
+ HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
+ HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
+ scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
+ scores = Q6_Vsf_equals_Vqf32(scores);
+ }
+
+ // 4. Online Softmax Update
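+                // Maintain the running max M and running sum S across blocks:
+                //   M_new = max(M, max_j s_j); VKQ *= exp(M - M_new); S = S * exp(M - M_new) + sum_j exp(s_j - M_new)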
+ HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
+ float m_block = hvx_vec_get_fp32(v_max);
+
+ float M_old = M;
+ float M_new = (m_block > M) ? m_block : M;
+ M = M_new;
+
+ float ms = expf(M_old - M_new);
+
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ S = S * ms;
+
+ HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
+ HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
+ HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
+
+ HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
+ float p_sum = hvx_vec_get_fp32(p_sum_vec);
+ S += p_sum;
+
+ // 5. Accumulate V
+ float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
+ *(HVX_Vector*)p_arr = P;
+
+ for (int j = 0; j < VLEN_FP32; ++j) {
+ const uint32_t cur_ic = ic + j;
+ const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
+ }
+ }
+
+ // Leftover
+ for (; ic < current_block_size; ++ic) {
+ float s_val;
+ const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
+
+ if (q->type == HTP_TYPE_F32) {
+ hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+ } else {
+ hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
+ }
+
+ if (logit_softcap != 0.0f) {
+ s_val = logit_softcap * tanhf(s_val);
+ }
+
+ if (mask) {
+ const float m_val = m_base[ic];
+ s_val += slope * m_val;
+ }
+
+ const float Mold = M;
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s_val > M) {
+ M = s_val;
+ ms = expf(Mold - M);
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ } else {
+ vs = expf(s_val - M);
+ }
+
+ const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
+
+ hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
+
+ S = S * ms + vs;
+ }
+
+            // Issue DMA for the next+1 block (if it exists)
+ if (ib + 2 < n_blocks) {
+ const uint32_t next_ib = ib + 2;
+ const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
+ const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
+
+ // K
+ const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
+ dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
+
+ // V
+ const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
+ dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
+
+ // Mask
+ if (mask) {
+ const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
+ dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
+ }
+ }
+ }
+
+ // sinks
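+        // Attention sinks add one extra per-head logit to the softmax normalization (rescaling the accumulator if it exceeds M) but contribute no value vector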
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ ms = expf(M - s);
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ S = S * ms + vs;
+ }
+
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+ hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
+
+ // Store result
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ // dst is permuted
+ uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
+
+ if (dst->type == HTP_TYPE_F32) {
+ hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+ } else if (dst->type == HTP_TYPE_F16) {
+ hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
+ }
+ }
+}
+
+static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+ flash_attn_ext_f16_thread(octx, i, n);
+}
+
+int op_flash_attn_ext(struct htp_ops_context * octx) {
+ const struct htp_tensor * q = &octx->src0;
+ const struct htp_tensor * k = &octx->src1;
+ const struct htp_tensor * v = &octx->src2;
+ const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
+ struct htp_tensor * dst = &octx->dst;
+
+ // Check support
+ if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
+ k->type != HTP_TYPE_F16 ||
+ v->type != HTP_TYPE_F16) {
+ return HTP_STATUS_NO_SUPPORT;
+ }
+
+ octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
+ octx->src0_div1 = init_fastdiv_values(q->ne[1]);
+
+ octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
+ octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
+ octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
+ octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
+
+ if (mask) {
+ octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
+ octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
+ }
+
+ size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
+ size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
+ size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
+
+ size_t size_q_block = size_q_row_padded * 1; // single row for now
+ size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
+ size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
+
+ size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
+
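+    // K, V and mask spads are double-buffered (x2) to overlap DMA with compute; Q and the VKQ accumulator need only a single buffer per thread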
+ octx->src0_spad.size_per_thread = size_q_block * 1;
+ octx->src1_spad.size_per_thread = size_k_block * 2;
+ octx->src2_spad.size_per_thread = size_v_block * 2;
+ octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
+ octx->dst_spad.size_per_thread = size_vkq_acc;
+
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+ octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
+ octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+
+ size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
+
+ if (octx->ctx->vtcm_size < total_spad) {
+ return HTP_STATUS_VTCM_TOO_SMALL;
+ }
+
+ octx->src0_spad.data = octx->ctx->vtcm_base;
+ octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
+ octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
+ octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
+ octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
+
+ if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
+ worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
+ }
+
+ return HTP_STATUS_OK;
+}
#include "hvx-utils.h"
#include "ops-utils.h"
+#define MM_SPAD_SRC0_NROWS 16
+#define MM_SPAD_SRC1_NROWS 16
+#define MM_SPAD_DST_NROWS 2
+
struct htp_matmul_type {
const char * type;
void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
- void (*vec_dot_rx2)(const int n,
- float * restrict s,
- const void * restrict vx,
- uint32_t vx_row_size,
- const void * restrict vy);
+ void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
};
typedef struct {
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
}
-#if 1
-static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
- if (0) {
- float rsum = 0;
- const __fp16 * restrict vx = (const __fp16 * restrict) x;
- const float * restrict vy = (const float * restrict) y;
+static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const HVX_Vector * restrict x = (const HVX_Vector *) vx;
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
- for (uint32_t i = 0; i < n; i++) {
- rsum += (float)vx[i] * vy[i];
- }
- *s = rsum;
- return;
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
+
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
- const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y;
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
- uint32_t nv0 = n / 64; // num full fp16 hvx vectors
- uint32_t nv1 = n % 64; // leftover elements
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
+}
+
+static void vec_dot_f16_f16_aa_rx2(const int n,
+ float * restrict s,
+ const void * restrict vx,
+ uint32_t vx_row_size,
+ const void * restrict vy) {
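+    // Two-row variant: dots the same y against two consecutive src0 rows (vx and vx + vx_row_size), amortizing the y loads and reductions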
+ const HVX_Vector * restrict x0 = (const HVX_Vector *) vx;
+ const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size);
+ const HVX_Vector * restrict y = (const HVX_Vector *) vy;
+
+ uint32_t nvec = n / VLEN_FP16;
+ uint32_t nloe = n % VLEN_FP16;
+
+ HVX_Vector rsum0 = Q6_V_vsplat_R(0);
+ HVX_Vector rsum1 = Q6_V_vsplat_R(0);
- // for some reason we need volatile here so that the compiler doesn't try anything funky
- volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
- float r_sum_scalar = 0.0f;
uint32_t i = 0;
- for (i = 0; i < nv0; i++) {
- HVX_VectorPair yp = vy[i];
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ HVX_Vector y_hf = y[i];
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf);
- HVX_Vector x = vx[i];
- HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
+ }
- //NOTE: need volatile here to prevent compiler optimization
- // Seem compiler cannot guarantee read-after-write??
- volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
- volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
+ HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
+
+ HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+ HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
- HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+ rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)));
+ rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
}
- if (nv1) {
- // HVX_VectorPair yp = vy[i];
+ rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0));
+ rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1));
+ HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
- // HVX_Vector x = vx[i];
- // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
+ hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+}
- // if (nv1 >= 32) {
- // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
- // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
- // nv1 -= 32;
- // }
+static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const HVX_UVector * restrict x = (const HVX_UVector *) vx;
+ const HVX_UVector * restrict y = (const HVX_UVector *) vy;
- // rsum = hvx_vec_qf32_reduce_sum(rsum);
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
- // if (nv1) {
- // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
- // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
- // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
- // }
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
- //process the remainder using scalar loop
- rsum = hvx_vec_qf32_reduce_sum(rsum);
- const __fp16 * restrict sx = (const __fp16 * restrict) x;
- const float * restrict sy = (const float * restrict) y;
+ uint32_t i = 0;
- for (uint32_t i = nv0 * 64; i < n; i++) {
- r_sum_scalar += (float) sx[i] * sy[i];
- }
+ #pragma unroll(4)
+ for (i = 0; i < nvec; i++) {
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+ }
+
+ if (nloe) {
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
+ HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
- // hvx_vec_dump_fp16("X", x);
- // hvx_vec_dump_fp16("Y", y);
- // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum));
- // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum));
- } else {
- rsum = hvx_vec_qf32_reduce_sum(rsum);
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
+}
-# ifdef HTP_DEBUG
- {
- float rsum = 0;
- const __fp16 * restrict vx = (const __fp16 * restrict) x;
- const float * restrict vy = (const float * restrict) y;
+static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
+ const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
+ const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
- for (uint32_t i = 0; i < n; i++) {
- rsum += vx[i] * vy[i];
- }
+ uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
+ uint32_t nloe = n % VLEN_FP16; // leftover elements
- float diff = fabs(*s - rsum);
- if (diff > 0.001) {
- FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s);
- // htp_dump_f16("x", vx, n);
- // htp_dump_f32("y", vy, n);
- }
+ const HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ HVX_Vector rsum = Q6_V_vsplat_R(0);
+
+ uint32_t i = 0;
+
+ #pragma unroll(2)
+ for (i = 0; i < nvec; i++) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
-# endif
-}
-#else
-static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
- const uint32_t fk = 64;
- const uint32_t nb = n / fk;
- assert(n % fk == 0);
- assert(nb % 4 == 0);
+ if (nloe) {
+ // Load y (fp32) and convert into fp16
+ HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
+ HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
+ HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
- const uint32_t x_blk_size = 2 * fk; // fp16
- const uint32_t y_blk_size = 4 * fk; // fp32
+ // Load x (fp16)
+ HVX_Vector x_hf = vx[i];
- // Row sum (qf32)
- HVX_Vector rsum0 = Q6_V_vsplat_R(0);
- HVX_Vector rsum1 = Q6_V_vsplat_R(0);
- HVX_Vector rsum2 = Q6_V_vsplat_R(0);
- HVX_Vector rsum3 = Q6_V_vsplat_R(0);
-
- for (uint32_t i = 0; i < nb; i += 4) {
- HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size));
- HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size));
-
- HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]);
- HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]);
- HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]);
- HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]);
-
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0)));
- rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1)));
- rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2)));
- rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3)));
+        // Zero out the unused leftover lanes
+        // Note that both x and y must be cleared because either one may contain NaNs
+ HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+ x_hf = Q6_V_vand_QV(bmask, x_hf);
+ y_hf = Q6_V_vand_QV(bmask, y_hf);
+
+ HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
+
+ rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
- // Reduce and convert into fp32
- rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1);
- rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3);
- HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2));
- hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum));
+ rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
+ hvx_vec_store_u(&s[0], 4, rsum);
}
-#endif
-#define htp_matmul_preamble \
+#define htp_matmul_tensors_preamble \
+ struct htp_tensor * restrict src0 = &octx->src0; \
+ struct htp_tensor * restrict src1 = &octx->src1; \
+ struct htp_tensor * restrict src2 = &octx->src2; \
+ struct htp_tensor * restrict dst = &octx->dst; \
+ struct htp_spad * restrict src0_spad = &octx->src0_spad; \
+ struct htp_spad * restrict src1_spad = &octx->src1_spad; \
+ struct htp_spad * restrict dst_spad = &octx->dst_spad; \
+ \
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne12 = src1->ne[2]; \
const uint32_t ne13 = src1->ne[3]; \
\
+ const uint32_t ne20 = src2->ne[0]; \
+ const uint32_t ne21 = src2->ne[1]; \
+ const uint32_t ne22 = src2->ne[2]; \
+ const uint32_t ne23 = src2->ne[3]; \
+ \
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3];
-// q8x4 src1 tensor is already in VTCM spad
-static void matmul(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+#define htp_matmul_preamble \
+ htp_matmul_tensors_preamble; \
+ dma_queue *dma_queue = octx->ctx->dma[ith]; \
+ uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
+
+// *** matmul with support for 4d tensors and full broadcasting
+
+static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
+ htp_matmul_preamble;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ assert(ne12 % ne02 == 0);
+ assert(ne13 % ne03 == 0);
+
+    // This is the size of the first dimension of the result, so we can iterate over it (per the asserts above, these are the same numbers)
+ const uint32_t nr0 = ne0;
+
+ // This is the size of the rest of the dimensions of the result
+ const uint32_t nr1 = ne1 * ne2 * ne3;
+
+ // distribute the thread work across the inner or outer loop based on which one is larger
+ uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+ uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+ // The number of elements in each chunk
+ const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+ const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
+
+ uint32_t current_chunk = ith;
+
+ const uint32_t ith0 = current_chunk % nchunk0;
+ const uint32_t ith1 = current_chunk / nchunk0;
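+    // thread ith owns chunk (ith0, ith1) of the nchunk0 x nchunk1 grid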
+
+ const uint32_t ir0_start = dr0 * ith0;
+ const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
+
+ const uint32_t ir1_start = dr1 * ith1;
+ const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+ // no work for this thread
+ if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+ return;
+ }
+
+ // block-tiling attempt
+ const uint32_t blck_0 = 64;
+ const uint32_t blck_1 = 64;
+
+ for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+ for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+ for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
+ const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1);
+ const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1);
+ const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+ // broadcast src0 into src1
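+                // (the broadcast ratios ne12/ne02 and ne13/ne03 are encoded in mm_div_r2/r3)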
+ const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3);
+ const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2);
+
+ const uint32_t i1 = i11;
+ const uint32_t i2 = i12;
+ const uint32_t i3 = i13;
+
+ const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+ const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
+ float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+ const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
+ for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
+ const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+ mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col);
+ }
+ }
+ }
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
+ src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// src1 tensor is already in VTCM spad
+static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const size_t dst_row_size = nb1;
const size_t src0_row_size = nb01;
- const size_t src1_row_size = q8x4x2_row_size(ne10);
+ const size_t src1_row_size = nb11;
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+ const size_t src0_stride = src0_spad->stride;
+ const size_t src1_stride = src1_spad->stride;
// Per-thread VTCM scratchpads for all tensors
// Note that the entire src1 tensor is already in VTCM
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
// Process src0 rows
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
- const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
- mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
+ mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
}
// Prefetch next (n + spad_nrows) row
- const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
}
if (src0_end_row != src0_end_row_x2) {
uint32_t ir0 = src0_end_row_x2;
const int is0 = (ir0 - src0_start_row);
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 1);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
#pragma unroll(2)
for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
- const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size);
+ const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col);
}
}
-// q8x4x2 src1 tensor is already in VTCM spad
+// src1 tensor is already in VTCM spad
-static void matvec(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
const uint32_t src0_nrows = ne01;
const size_t dst_row_size = nb1;
const size_t src0_row_size = nb01;
- const size_t src1_row_size = q8x4x2_row_size(ne10);
+ const size_t src1_row_size = nb11;
- const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128);
+ const size_t src0_stride = src0_spad->stride;
+ const size_t src1_stride = src1_spad->stride;
// Per-thread VTCM scratchpads for all tensors
// Note that the entire src1 tensor is already in VTCM
#pragma unroll(2)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint32_t is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
// Process src0 rows
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
- mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col);
+ mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col);
// Prefetch next (n + spad_nrows) row
- const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 2);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size),
+ src0_stride, src0_row_size, 2);
}
}
if (src0_end_row != src0_end_row_x2) {
const uint32_t ir0 = src0_end_row_x2;
const uint32_t is0 = (ir0 - src0_start_row);
- dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
- src0_row_size_padded, src0_row_size, 1);
+ dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size),
+ src0_stride, src0_row_size, 1);
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col);
}
uint32_t i2;
};
-// q8x4 src1 tensor is already in VTCM spad
-static void matmul_id(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict ids,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict src2_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+// src1 tensor is already in VTCM spad
+static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
+ struct htp_tensor * restrict ids = &octx->src2;
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
}
// Prefetch next (n + spad_nrows) row
- const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
src0_row_size_padded, src0_row_size, 2);
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-// q8x4 src1 tensor is already in VTCM spad
-static void matvec_id(struct htp_matmul_type * mt,
- struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict src2,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict src2_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
+// src1 tensor is already in VTCM spad
+static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) {
htp_matmul_preamble;
+ struct htp_tensor * restrict ids = &octx->src2;
+ struct htp_spad * restrict src2_spad = &octx->src2_spad;
+
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
#pragma unroll(4)
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const int is0 = (ir0 - src0_start_row);
- if (is0 >= HTP_SPAD_SRC0_NROWS) {
+ if (is0 >= MM_SPAD_SRC0_NROWS) {
break;
}
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size),
mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col);
// Prefetch next (n + spad_nrows) row
- const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS);
- const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS;
+ const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS);
+ const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS;
if (pr0 < src0_end_row_x2) {
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size),
src0_row_size_padded, src0_row_size, 2);
dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
-// *** matmul in fp16
-
-static void matmul_f16_f32(struct htp_tensor * restrict src0,
- struct htp_tensor * restrict src1,
- struct htp_tensor * restrict dst,
- struct htp_spad * restrict src0_spad,
- struct htp_spad * restrict src1_spad,
- struct htp_spad * restrict dst_spad,
- uint32_t nth,
- uint32_t ith,
- uint32_t src0_nrows_per_thread,
- dma_queue * dma_queue) {
- htp_matmul_preamble;
-
- uint64_t t1, t2;
- t1 = HAP_perf_get_qtimer_count();
-
- assert(ne12 % ne02 == 0);
- assert(ne13 % ne03 == 0);
-
- // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
- const uint32_t nr0 = ne0;
-
- // This is the size of the rest of the dimensions of the result
- const uint32_t nr1 = ne1 * ne2 * ne3;
-
- // distribute the thread work across the inner or outer loop based on which one is larger
- uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
- uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
-
- // The number of elements in each chunk
- const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
- const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
-
- uint32_t current_chunk = ith;
-
- const uint32_t ith0 = current_chunk % nchunk0;
- const uint32_t ith1 = current_chunk / nchunk0;
-
- const uint32_t ir0_start = dr0 * ith0;
- const uint32_t ir0_end = MIN(ir0_start + dr0, nr0);
-
- const uint32_t ir1_start = dr1 * ith1;
- const uint32_t ir1_end = MIN(ir1_start + dr1, nr1);
-
- // broadcast factors
- const uint32_t r2 = ne12 / ne02;
- const uint32_t r3 = ne13 / ne03;
-
- // no work for this thread
- if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
- return;
- }
-
- // block-tiling attempt
- const uint32_t blck_0 = 64;
- const uint32_t blck_1 = 64;
-
- __attribute__((aligned(128))) float tmp[64];
-
- for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
- for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
- for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
- const uint32_t i13 = (ir1 / (ne12 * ne1));
- const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
- const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
-
- // broadcast src0 into src1
- const uint32_t i03 = i13 / r3;
- const uint32_t i02 = i12 / r2;
-
- const uint32_t i1 = i11;
- const uint32_t i2 = i12;
- const uint32_t i3 = i13;
-
- const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
- const uint8_t * restrict src1_col =
- (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
- float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
-
- const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
- for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
- // Use nb01 stride for non-contiguous src0 support
- const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
- vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col);
- }
-
- hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
- }
- }
- }
-
- t2 = HAP_perf_get_qtimer_count();
-
- FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
- src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0],
- src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
- (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
-}
-
// *** dynamic quant
static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
for (uint32_t i = 0; i < nb; i++) {
#if FP32_QUANTIZE_GROUP_SIZE == 32
- quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
- t_d + (i * 2 + 0) * dblk_size / 2);
- quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
- t_d + (i * 2 + 1) * dblk_size / 2);
+ quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 64
- quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
- t_d + (i * 2 + 0) * dblk_size / 2);
- quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
- t_d + (i * 2 + 1) * dblk_size / 2);
+ quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#elif FP32_QUANTIZE_GROUP_SIZE == 128
- quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
- t_d + (i * 2 + 0) * dblk_size / 2);
- quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
- t_d + (i * 2 + 1) * dblk_size / 2);
+ quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2);
+ quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2);
#else
#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
#endif
ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
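+
+// Convert a range of src rows from fp32 to fp16 into the VTCM spad (dst_stride bytes per row)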
+static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
+ uint32_t nrows_per_thread, uint32_t dst_stride) {
+
+ uint64_t t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t ne0 = src->ne[0];
+ const uint32_t ne1 = src->ne[1];
+ const uint32_t ne2 = src->ne[2];
+ const uint32_t ne3 = src->ne[3];
+
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+ const size_t src_row_size = ne0 * sizeof(float);
+ const size_t src_stride = src->nb[1];
+
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
+
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
+ htp_l2fetch(src_data, 2, src_row_size, src_stride);
+ hvx_copy_fp16_fp32_au(dst_data, src_data, ne0);
+
+ dst_data += dst_stride;
+ src_data += src_stride;
+ }
+
+ uint64_t t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+// TODO: this is just a plain copy; it should be done via DMA during the op setup
+static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith,
+ uint32_t nrows_per_thread, uint32_t dst_stride) {
+
+ uint64_t t1 = HAP_perf_get_qtimer_count();
+
+ const uint32_t ne0 = src->ne[0];
+ const uint32_t ne1 = src->ne[1];
+ const uint32_t ne2 = src->ne[2];
+ const uint32_t ne3 = src->ne[3];
+
+ const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows
+
+ const uint32_t ir_first = nrows_per_thread * ith; // first row
+ const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row
+
+    const size_t src_row_size = ne0 * sizeof(__fp16); // fp16 source rows
+ const size_t src_stride = src->nb[1];
+
+ uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first);
+ uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first);
+
+ for (uint32_t i = ir_first; i < ir_last; ++i) {
+ htp_l2fetch(src_data, 2, src_row_size, src_stride);
+ hvx_copy_fp16_au(dst_data, src_data, ne0);
+
+ dst_data += dst_stride;
+ src_data += src_stride;
+ }
+
+ uint64_t t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first,
+ ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread);
}
-// ** matmul callbacks for worker_pool
+static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+ quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
+}
+
+static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+ quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride);
+}
+
+// ** matmul/matvec callbacks for worker_pool
-static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_2d(&mt, octx, n, i);
}
-static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_2d(&mt, octx, n, i);
}
-static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_2d(&mt, octx, n, i);
}
-static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_2d(&mt, octx, n, i);
}
-static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_2d(&mt, octx, n, i);
}
-static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
+static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_2d(&mt, octx, n, i);
}
-static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) {
+static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
- matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i,
- octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f16";
+ mt.vec_dot = vec_dot_f16_f16_aa;
+ mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
+
+ matvec_2d(&mt, octx, n, i);
+}
+
+static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f16";
+ mt.vec_dot = vec_dot_f16_f16_aa;
+ mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2;
+
+ matmul_2d(&mt, octx, n, i);
+}
+
+static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f32";
+ mt.vec_dot = vec_dot_f16_f32_uu;
+
+ matmul_4d(&mt, octx, n, i);
+}
+
+static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = data;
+
+ struct htp_matmul_type mt;
+ mt.type = "f16-f16";
+ mt.vec_dot = vec_dot_f16_f16_uu;
+
+ matmul_4d(&mt, octx, n, i);
}
// ** matmul-id callbacks for worker_pool
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_id(&mt, octx, n, i);
}
static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
- matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_id(&mt, octx, n, i);
}
static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_id(&mt, octx, n, i);
}
static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
- matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_id(&mt, octx, n, i);
}
static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matvec_id(&mt, octx, n, i);
}
static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) {
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
- matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad,
- &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
+ matmul_id(&mt, octx, n, i);
}
// ** main matmul entry point
-int op_matmul(struct htp_ops_context * octx) {
- const struct htp_tensor * src0 = &octx->src0;
- const struct htp_tensor * src1 = &octx->src1;
- struct htp_tensor * dst = &octx->dst;
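+// Returns true if the tensor's byte strides are not in non-decreasing order (nb[0] <= nb[1] <= nb[2] <= nb[3]),
+// i.e. the view is permuted or transposed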
+static inline bool htp_is_permuted(const struct htp_tensor * t) {
+ return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3];
+}
- htp_matmul_preamble;
+int op_matmul(struct htp_ops_context * octx) {
+ htp_matmul_tensors_preamble;
const char * op_type;
op_type = "q4x4x2-fp32";
quant_job_func = htp_quantize_fp32_q8x4x2;
if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_q4x4x2_q8x4x2;
+ matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2;
} else {
- matmul_job_func = htp_matvec_q4x4x2_q8x4x2;
+ matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2;
}
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
op_type = "q8x4x2-fp32";
quant_job_func = htp_quantize_fp32_q8x4x2;
if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_q8x4x2_q8x4x2;
+ matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2;
} else {
- matmul_job_func = htp_matvec_q8x4x2_q8x4x2;
+ matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2;
}
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
op_type = "mxfp4x4x2-f32";
quant_job_func = htp_quantize_fp32_q8x4x2;
if (src1_nrows > 1) {
- matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2;
+ matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2;
} else {
- matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2;
+ matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2;
}
src1_row_size = q8x4x2_row_size(ne10); // row size post quantization
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
break;
case HTP_TYPE_F16:
- op_type = "f16-f32";
- quant_job_func = NULL; // htp_quantize_f32_f16;
- matmul_job_func = htp_matmul_f16_f32;
-
- // For all tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256);
- octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256);
-
- octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
- octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
- octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
-
- need_quant = false;
+ {
+ // Try optimized f16-f16 path first (src1 in VTCM)
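+ // Estimate the total VTCM footprint: the entire src1 as padded fp16 rows plus per-thread src0 and dst scratchpads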
+ const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128);
+ const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256);
+ const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads;
+ const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads;
+
+ const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size;
+
+ // The optimized 2D matmul implementation does not support multi-batch src0 (N-vs-N broadcasting);
+ // it only handles 1-vs-N broadcasting (src0 is 2D) or a standard 2D matmul.
+ const bool is_batched = (ne02 > 1) || (ne03 > 1);
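+ // Permuted/transposed views also need the stride-aware 4D path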
+ const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1);
+
+ if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) {
+ // Optimized path
+ op_type = "f16-f16";
+ quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16;
+ if (src1_nrows > 1) {
+ matmul_job_func = htp_matmul_2d_f16_f16;
+ } else {
+ matmul_job_func = htp_matvec_2d_f16_f16;
+ }
+
+ src1_row_size = f16_src1_row_size; // row size after conversion to padded fp16
+
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
+
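+ // The src1 scratchpad holds the entire converted src1 and is shared by all threads; src0 and dst scratchpads are per-thread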
+ octx->src1_spad.size = octx->src1_spad.size_per_thread;
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+ } else {
+ // Fall back to the 4D f16/f32 path (src1 stays in DDR) if src1 doesn't fit in VTCM, the inputs are permuted, or batched broadcasting is required
+ quant_job_func = NULL;
+ if (src1->type == HTP_TYPE_F32) {
+ op_type = "f16-f32";
+ matmul_job_func = htp_matmul_4d_f16_f32;
+ } else {
+ op_type = "f16-f16";
+ matmul_job_func = htp_matmul_4d_f16_f16;
+ }
+
+ src1_row_size = nb11; // original row size in DDR
+
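+ // For all tensors we allocate N rows per thread, padded to HVX vector size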
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256);
+ octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256);
+
+ octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
+ octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
+ octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
+
+ // Init fastdiv for matmul_4d (supports broadcasting)
+ octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]);
+ octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]);
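+ // r2/r3 are the broadcast ratios of src1 over src0 along dims 2 and 3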
+ octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]);
+ octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]);
+
+ need_quant = false;
+ }
+ }
break;
default:
octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads;
octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even
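+ // Scratchpad row strides; the convert callbacks use src1_spad.stride when packing rows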
+ octx->src0_spad.stride = src0_row_size_padded;
+ octx->src1_spad.stride = src1_row_size;
+
if (need_quant) {
// Run quant jobs
const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads);
// ** main matmul-id entry point
int op_matmul_id(struct htp_ops_context * octx) {
- const struct htp_tensor * src0 = &octx->src0;
- const struct htp_tensor * src1 = &octx->src1;
- const struct htp_tensor * ids = &octx->src2;
- struct htp_tensor * dst = &octx->dst;
+ htp_matmul_tensors_preamble;
- htp_matmul_preamble;
+ struct htp_tensor * restrict ids = &octx->src2;
const char * op_type;
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);
// Entire src1 tensor is placed into the VTCM
// For other tensors we allocate N rows per thread, padded to HVX vector size
- octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256);
- octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
+ octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256);
+ octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256);
octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256);
octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256);