]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-hexagon: flash-attention and reduce-sum optimizations (#19141)
authornullname <redacted>
Sat, 31 Jan 2026 05:14:20 +0000 (13:14 +0800)
committerGitHub <redacted>
Sat, 31 Jan 2026 05:14:20 +0000 (21:14 -0800)
* wip

* ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation

* ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations

* wip

* ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance

* ggml-hexagon: refactor dot product functions to use a common loading function for improved readability

* optimize vector dot product functions to use unified reduction for improved performance

* wip

* ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation

* ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations

* wip

* ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance

* ggml-hexagon: refactor dot product functions to use a common loading function for improved readability

* optimize vector dot product functions to use unified reduction for improved performance

* hexagon: optimize reduce-sum for v75+

* hexagon: always keep row_sums in sf/fp32

* ggml-hexagon: enhance directory checks for HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT

* fix compiling error after rebase

---------

Co-authored-by: Max Krasnyansky <redacted>
ggml/src/ggml-hexagon/CMakeLists.txt
ggml/src/ggml-hexagon/htp/flash-attn-ops.c
ggml/src/ggml-hexagon/htp/hvx-dump.h
ggml/src/ggml-hexagon/htp/hvx-reduce.h
ggml/src/ggml-hexagon/htp/matmul-ops.c
ggml/src/ggml-hexagon/htp/softmax-ops.c
ggml/src/ggml-hexagon/htp/unary-ops.c

index 2b69197017fcaddd17e496268e587537eab8305b..f3a583543c6392bde9dc6c64fefd6382d10075a1 100644 (file)
@@ -1,8 +1,20 @@
 file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}"   HEXAGON_SDK_ROOT)
 file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
 
-if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
-    message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
+if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}")
+    message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.")
+endif()
+
+if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+    message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json")
+    file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH)
+    string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path")
+    message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}")
+    set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}")
+    file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
+    if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
+        message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
+    endif()
 endif()
 
 message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
index c7cb2a4e0bc594a07bccf5fd07ce5290fbe1fe25..c1846374437abb2f97c31d0144d282967150ff9f 100644 (file)
 #include "htp-msg.h"
 #include "htp-ops.h"
 
+static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
+    HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero);  // 32 elements
+    HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero);  // 32 elements
+    return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+}
+
 // Dot product of FP32 and FP16 vectors, accumulating to float
 static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
     const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
@@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
     #pragma unroll(4)
     for (i = 0; i < nvec; i++) {
         // Load y (fp32) and convert into fp16
-        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
-        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
-        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
 
         // Load x (fp16)
         HVX_Vector x_hf  = vx[i];
 
         HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
 
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
     }
 
     if (nloe) {
         // Load y (fp32) and convert into fp16
-        HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero);  // 32 elements
-        HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero);  // 32 elements
-        HVX_Vector y_hf  = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
+        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
 
         // Load x (fp16)
         HVX_Vector x_hf  = vx[i];
@@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
 
         HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
 
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
+        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
     }
 
-    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
+    rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
+    hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
+}
+
+// Dot product of FP32 and FP16 vectors, accumulating to float
+static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
+                                          const void * restrict y,
+                                          const void * restrict x0,
+                                          const void * restrict x1,
+                                          unsigned int n,
+                                          float        s) {
+    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;   // fp32
+    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0;  // fp16
+    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1;  // fp16
+
+    uint32_t nvec = n / VLEN_FP16;                                       // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16;                                       // leftover elements
+
+    const HVX_Vector zero  = Q6_V_vsplat_R(0);
+    HVX_Vector       rsum0 = Q6_V_vsplat_R(0);
+    HVX_Vector       rsum1 = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
 
-    hvx_vec_store_u(r, 4, rsum);
+    #pragma unroll(2)
+    for (i = 0; i < nvec; i++) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
+        // Load x (fp16)
+        HVX_Vector x0_hf = vx0[i];
+        HVX_Vector x1_hf = vx1[i];
+
+        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+    }
+
+    if (nloe) {
+        // Load y (fp32) and convert into fp16
+        HVX_Vector y_hf  = hvx_load_f32_to_f16(&vy[i*2], zero);
+
+        // Load x (fp16)
+        HVX_Vector x0_hf = vx0[i];
+        HVX_Vector x1_hf = vx1[i];
+
+        // Zero-out unused elements
+        // Note that we need to clear both x and y because they may contain NANs
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        x0_hf                = Q6_V_vand_QV(bmask, x0_hf);
+        x1_hf                = Q6_V_vand_QV(bmask, x1_hf);
+        y_hf                 = Q6_V_vand_QV(bmask, y_hf);
+
+        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+    }
+
+    HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
+    hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
 }
 
 // Dot product of two F16 vectors, accumulating to float
@@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
 
         HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
 
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
     }
 
     if (nloe) {
@@ -103,12 +164,62 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
 
         HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
 
-        rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
+        rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
+    }
+
+    rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
+    hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
+}
+
+static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
+                                          const void * restrict y,
+                                          const void * restrict x0,
+                                          const void * restrict x1,
+                                          unsigned int n,
+                                          float        s) {
+    const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0;  // fp16
+    const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1;  // fp16
+    const HVX_Vector * restrict vy  = (const HVX_Vector * restrict) y;   // fp16
+
+    uint32_t nvec = n / VLEN_FP16;                                       // num full fp16 hvx vectors
+    uint32_t nloe = n % VLEN_FP16;                                       // leftover elements
+
+    const HVX_Vector zero  = Q6_V_vsplat_R(0);
+    HVX_Vector       rsum0 = Q6_V_vsplat_R(0);
+    HVX_Vector       rsum1 = Q6_V_vsplat_R(0);
+
+    uint32_t i = 0;
+
+    #pragma unroll(4)
+    for (i = 0; i < nvec; i++) {
+        HVX_Vector y_hf  = vy[i];
+        HVX_Vector x0_hf = vx0[i];
+        HVX_Vector x1_hf = vx1[i];
+
+        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
+    }
+
+    if (nloe) {
+        HVX_Vector y_hf = vy[i];
+
+        // Load x (fp16) and zero-out unused elements
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
+        HVX_Vector     x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
+        HVX_Vector     x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
+
+        HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
+        HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
+
+        rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
+        rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
     }
 
-    rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s));
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
-    hvx_vec_store_u(r, 4, rsum);
+    HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
+    hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
 }
 
 // MAD: y (F32) += x (F16) * s (float)
@@ -317,20 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
             // Inner loop processing the block from VTCM
             uint32_t ic = 0;
 
+            const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
+
             // Process in blocks of 32 (VLEN_FP32)
-            static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
+            static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
             HVX_Vector_x4 scores_x4;
             HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
             for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
                 // 1. Compute scores
-                float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
-                for (int j = 0; j < VLEN_FP32; ++j) {
+                float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
+                for (int j = 0; j < VLEN_FP32; j += 2) {
                     const uint32_t cur_ic = ic + j;
                     const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
-                    if (q->type == HTP_TYPE_F32) {
-                        hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+                    if (is_q_fp32) {
+                        hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
                     } else {
-                        hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
+                        hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
                     }
                 }
 
@@ -403,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
                 float s_val;
                 const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
 
-                if (q->type == HTP_TYPE_F32) {
+                if (is_q_fp32) {
                     hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
                 } else {
                     hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
index e882227893eb03114d9e063c2cd1dd2a085e4150..85201fc345326ece2310f7ad3f5c7e9a459d5a13 100644 (file)
@@ -28,19 +28,16 @@ static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
 }
 
 static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
-    union {
-        HVX_Vector v;
-        float      d[32];
-    } u = { .v = v };
+    HVX_VectorAlias u = { .v = v };
 
     const uint32_t n0 = n / 16;
     const uint32_t n1 = n % 16;
     int            i  = 0;
     for (; i < n0; i++) {
-        hex_dump_f32_line(pref, u.d + (16 * i), 16);
+        hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);
     }
     if (n1) {
-        hex_dump_f32_line(pref, u.d + (16 * i), n1);
+        hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);
     }
 }
 
index 8845fe73ea175c527e9d23fcf8287a22abf46e30..1ca7c05d983edd14ba84287195b6a3a1ebb16bfc 100644 (file)
@@ -44,6 +44,45 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
     return hvx_vec_reduce_sum_n_qf32(in, 32);
 }
 
+#if __HVX_ARCH__ > 75
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+    HVX_Vector  sum_sf  = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
+    sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
+    return sum_sf;
+}
+
+static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
+    unsigned int total = n * 4;  // total vec nbytes
+    unsigned int width = 4;      // fp32 nbytes
+
+    HVX_Vector sum = in, sum_t;
+    while (width < total) {
+        sum_t = Q6_V_vror_VR(sum, width);       // rotate right
+        sum   = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
+        width = width << 1;
+    }
+    return sum;
+}
+
+#else
+
+static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
+    HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
+    HVX_Vector  sum_qf  = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
+
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
+    sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
+    return Q6_Vsf_equals_Vqf32(sum_qf);
+}
+
 static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
     unsigned int total = n * 4;  // total vec nbytes
     unsigned int width = 4;      // fp32 nbytes
@@ -57,6 +96,8 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n)
     return sum;
 }
 
+#endif
+
 static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
     return hvx_vec_reduce_sum_n_f32(in, 32);
 }
index 1603ff2b3b6a0dd256ad993248049b4f088db605..d251eeed33a97af93d5299935f0b2b87b4681398 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "hex-dma.h"
 #include "hvx-utils.h"
+#include "hvx-dump.h"
 
 #define GGML_COMMON_DECL_C
 #include "ggml-common.h"
@@ -320,7 +321,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
     const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
     const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
 
-    // Row sum (qf32)
+    // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 
     // Multiply and accumulate into int32.
@@ -344,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
     // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@@ -362,14 +363,14 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
         // Zero out unused scales
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Reduce and convert into fp32
-    r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
     hvx_vec_store_u(&s[0], 4, r0_sum);
 }
@@ -402,7 +403,7 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
     const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
     const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
 
-    // Row sum (qf32)
+    // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
     HVX_Vector r1_sum = Q6_V_vsplat_R(0);
 
@@ -432,8 +433,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
     // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@@ -456,20 +457,18 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
         r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Convert into fp32 and reduce
-    r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(&s[0], 8, rsum);
 }
 
 static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -493,7 +492,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
     const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
     const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
 
-    // Row sum (qf32)
+    // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 
     // Multiply and accumulate into int32.
@@ -517,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
     // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@@ -535,14 +534,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
         // Zero out unused scales
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Reduce and convert into fp32
-    r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
     hvx_vec_store_u(&s[0], 4, r0_sum);
 }
@@ -605,8 +604,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
     // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@@ -629,20 +628,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
         r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Convert into fp32 and reduce
-    r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(&s[0], 8, rsum);
 }
 
 static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
@@ -669,7 +666,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
     const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);               // quants first
     const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);     // then scales
 
-    // Row sum (qf32)
+    // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
 
     // Multiply and accumulate into int32.
@@ -708,7 +705,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
     // Process leftovers
@@ -741,14 +738,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
         // Zero-out unused scales
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
     }
 
-    // Reduce and convert into fp32
-    r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
+    r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
 
     hvx_vec_store_u(&s[0], 4, r0_sum);
 }
@@ -781,13 +778,13 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
     const uint8_t * restrict y_q = ((const uint8_t *) vy + 0);                                     // quants first
     const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size);                           // then scales
 
-    // Row sum (qf32)
+    // Row sum (sf)
     HVX_Vector r0_sum = Q6_V_vsplat_R(0);
     HVX_Vector r1_sum = Q6_V_vsplat_R(0);
 
     // Multiply and accumulate into int32.
     // Compute combined scale (fp32).
-    // Apply scale to acc and accumulate into the row sum (qf32).
+    // Apply scale to acc and accumulate into the row sum (f32).
 
     const uint32_t nb   = n / qk;  // num full blocks
     int32_t        nloe = n % qk;  // num leftover elemements (must be signed)
@@ -829,8 +826,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
     // Process leftovers
@@ -867,24 +864,22 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
         HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
         HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
 
-        // Zero-out unused scales
+        // Zero-out unused values
         HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
         r0_dd                = Q6_V_vand_QV(bmask, r0_dd);
         r1_dd                = Q6_V_vand_QV(bmask, r1_dd);
+        r0_ia                = Q6_V_vand_QV(bmask, r0_ia);
+        r1_ia                = Q6_V_vand_QV(bmask, r1_ia);
 
         HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
         HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
 
-        r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa);
-        r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa);
+        r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
+        r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
     }
 
-    // Convert into fp32 and reduce
-    r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
-    r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
+    hvx_vec_store_u(&s[0], 8, rsum);
 }
 
 static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -913,7 +908,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res
         rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
+    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
     hvx_vec_store_u(&s[0], 4, rsum);
 }
 
@@ -957,11 +952,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n,
         rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
     }
 
-    rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0));
-    rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1));
-    HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
-
-    hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
+    HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
+    hvx_vec_store_u(&s[0], 8, rsum);
 }
 
 static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -990,7 +982,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
         rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
+    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
     hvx_vec_store_u(&s[0], 4, rsum);
 }
 
@@ -1042,7 +1034,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
         rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf),  Q6_V_hi_W(xy_qf)));
     }
 
-    rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum));
+    // Convert into fp32 and reduce
+    rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
     hvx_vec_store_u(&s[0], 4, rsum);
 }
 
index 1b6b2eba4ae34b5c6c4442bd32d5f6415ce9be92..e91a16d947f56f0a4db7653388a4346970fc75b1 100644 (file)
@@ -154,8 +154,8 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
         v_pad[i] = v3;
     }
 
-    v       = hvx_vec_reduce_sum_qf32(sum_vec);
-    sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
+    v       = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
+    sum_vec = hvx_vec_repl4(v);
 
     HVX_VectorPred pos_sum   = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
     HVX_Vector     v4        = hvx_vec_inverse_f32(sum_vec);
index be8be8c4e6429ede055ca9a447cc32df79305194..1a27cb6e63e533bb7580233d2807bb5afcdd03f5 100644 (file)
@@ -57,8 +57,8 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
         sum_v         = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
     }
 
-    HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v);
-    sum_v                  = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
+    HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
+    sum_v                  = hvx_vec_repl4(reduced_sum);
 
     HVX_Vector t_v            = hvx_vec_splat_f32((float) num_elems);
     HVX_Vector denom_v        = hvx_vec_inverse_f32(t_v);