]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
hexagon: improve RMS_NORM and DIV accuracy (#21251)
authorAparna M P <redacted>
Wed, 1 Apr 2026 15:43:08 +0000 (21:13 +0530)
committerGitHub <redacted>
Wed, 1 Apr 2026 15:43:08 +0000 (08:43 -0700)
* hexagon-rms_norm: fix RMS_NORM for non-aligned tensor sizes

Co-authored-by: Krishna Sridhar <redacted>
* hexagon-div: perform DIV in fp16 domain for lower dsp archs

---------

Co-authored-by: Krishna Sridhar <redacted>
ggml/src/ggml-hexagon/htp/hvx-div.h
ggml/src/ggml-hexagon/htp/unary-ops.c

index 05cefea039f6fd91e05265c2b77525ed4a43732b..53ee304e749bd743510f60e3e84cff3d808b9e9f 100644 (file)
 
 #if __HVX_ARCH__ < 79
 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b))
 #else
 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b)
 #endif
 
 // Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32.
@@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX
     return res;
 }
 
-#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store)                     \
-    do {                                                                                \
-        dst_type * restrict vdst = (dst_type *) dst;                                    \
-        src_type * restrict vsrc = (src_type *) src;                                    \
-        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                     \
-                                                                                        \
-        const uint32_t nvec = n / VLEN_FP16;                                            \
-        const uint32_t nloe = n % VLEN_FP16;                                            \
-                                                                                        \
-        uint32_t i = 0;                                                                 \
-                                                                                        \
-        _Pragma("unroll(4)")                                                            \
-        for (; i < nvec; i++) {                                                         \
-            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
-            vdst[i] = res;                                                              \
-        }                                                                               \
-        if (nloe) {                                                                     \
-            HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
-            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                      \
-        }                                                                               \
+// Variant for <v79: Use pre-computed f16 reciprocal constant
+static inline HVX_Vector hvx_div_mul_f16_const_using_f16(HVX_Vector vec1_hf, HVX_Vector const_inv_hf) {
+    // Multiply by pre-computed f16 reciprocal constant
+    return HVX_OP_MUL_F16(vec1_hf, const_inv_hf);
+}
+
+#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store)                                    \
+    do {                                                                                               \
+        dst_type * restrict vdst = (dst_type *) dst;                                                   \
+        src_type * restrict vsrc = (src_type *) src;                                                   \
+                                                                                                       \
+        HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                                    \
+                                                                                                       \
+        const uint32_t nvec = n / VLEN_FP16;                                                           \
+        const uint32_t nloe = n % VLEN_FP16;                                                           \
+                                                                                                       \
+        uint32_t i = 0;                                                                                \
+                                                                                                       \
+        _Pragma("unroll(4)")                                                                           \
+        for (; i < nvec; i++) {                                                                        \
+            HVX_Vector res;                                                                            \
+            if (__HVX_ARCH__ < 79) {                                                                   \
+                res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16);                           \
+            } else {                                                                                   \
+                res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one);                   \
+            }                                                                                          \
+            vdst[i] = res;                                                                             \
+        }                                                                                              \
+        if (nloe) {                                                                                    \
+            HVX_Vector res;                                                                            \
+            if (__HVX_ARCH__ < 79) {                                                                   \
+                res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16);                           \
+            } else {                                                                                   \
+                res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one);                   \
+            }                                                                                          \
+            vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                                     \
+        }                                                                                              \
     } while(0)
 
 static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     assert((uintptr_t) dst % 128 == 0);
     assert((uintptr_t) src % 128 == 0);
     hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
 }
 static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     assert((uintptr_t) dst % 128 == 0);
     hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
 }
 static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     assert((uintptr_t) src % 128 == 0);
     hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
 }
 static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
     const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
+    const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
     hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
 }
 
@@ -128,13 +151,25 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
     return recip;
 }
 
+// Hybrid approach: f16 reciprocal for <v79, f32 precision for >=v79
+static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
+#if __HVX_ARCH__ < 79
+    // For older architectures, use f16 reciprocal to avoid NaN/-inf issues
+    HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask);
+    return HVX_OP_MUL_F16(vec1, vec2_inv);
+#else
+    return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0);
+#endif
+}
+
 #define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store)                  \
     do {                                                                                  \
         dst_type * restrict vdst = (dst_type *) dst;                                      \
         src0_type * restrict vsrc0 = (src0_type *) src0;                                  \
         src1_type * restrict vsrc1 = (src1_type *) src1;                                  \
                                                                                           \
-        const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                        \
+        const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000);                    \
+        const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00);                       \
         const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00);                                 \
                                                                                           \
         const uint32_t nvec = n / VLEN_FP16;                                              \
@@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
                                                                                           \
         _Pragma("unroll(4)")                                                              \
         for (; i < nvec; i++) {                                                           \
-            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i],                   \
+                                                    f32_nan_inf_mask, f16_nan_inf_mask,   \
+                                                    hf_one);                              \
             vdst[i] = res;                                                                \
         }                                                                                 \
         if (nloe) {                                                                       \
-            HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
+            HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i],                   \
+                                                    f32_nan_inf_mask, f16_nan_inf_mask,   \
+                                                    hf_one);                              \
             vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res);                        \
         }                                                                                 \
     } while(0)
@@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32)
 HVX_DIV_DISPATCHER(hvx_div_f16)
 
 #undef HVX_OP_MUL_F32
+#undef HVX_OP_MUL_F16
 
 #endif // HVX_DIV_H
index 3d0928d4dce164d802593c6ca6a411abffac43c1..13d28317d5c970ce704baca316320d1f54935201 100644 (file)
@@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
                                   uint8_t * restrict pad,
                                   const int num_elems,
                                   float     epsilon) {
+    (void)pad;
+
     const HVX_Vector * restrict v_src = (HVX_Vector *) src;
     HVX_Vector * restrict v_dst       = (HVX_Vector *) dst;
 
-    HVX_Vector sum_v     = Q6_V_vsplat_R(0x00000000);
+    const int nvec = num_elems / VLEN_FP32;    // number of full vectors
+    const int nloe = num_elems % VLEN_FP32;    // leftover elements
+
+    // Compute sum of squares for full vectors
+    HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
     HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
 
-    int step_of_1 = num_elems >> 5;
     #pragma unroll(4)
-    for (int i = 0; i < step_of_1; i++) {
+    for (int i = 0; i < nvec; i++) {
         HVX_Vector v1 = v_src[i];
         HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
-        sum_v         = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
+        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
+    }
+
+    // Handle tail elements using vectorized ops with masking
+    if (nloe > 0) {
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
+        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
+        sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
     }
 
-    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
+    // Reduce HVX sum
+    sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
 
     HVX_Vector t_v            = hvx_vec_splat_f32((float) num_elems);
     HVX_Vector denom_v        = hvx_vec_inverse_f32(t_v);
     HVX_Vector mean_v         = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
     HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
 
+    // Scale full vectors
     HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
 
     #pragma unroll(4)
-    for (int i = 0; i < step_of_1; i++) {
+    for (int i = 0; i < nvec; i++) {
         HVX_Vector v1 = v_src[i];
         HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
-        v_dst[i]      = Q6_Vsf_equals_Vqf32(v2);
+        v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
+    }
+
+    // Handle tail elements using vectorized ops with masking
+    if (nloe > 0) {
+
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
+        HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
+        HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
+        HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);
+
+        // Store with masking to avoid overwriting memory beyond the tensor
+        hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
     }
 }