]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : optimize and build warning fix for LoongArch (llama/11709)
authorJinyang He <redacted>
Fri, 7 Feb 2025 07:38:31 +0000 (15:38 +0800)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
* ggml : optimize convert f32<->f16 for loongarch_asx

* ggml : optimize loongarch_asx extend i16,i8,u8 to i32,i16

* ggml : Fix warnings when run cpu CI locally on LoongArch

ggml/src/ggml-cpu/ggml-cpu-impl.h
ggml/src/ggml-cpu/ggml-cpu-quants.c
ggml/src/ggml-cpu/ggml-cpu.c

index d71076ad12b1f656a3a3c4ee0d670397754ffc06..9ddd972a5cfcf1db5b639e40149a3275e549539c 100644 (file)
@@ -360,21 +360,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 #endif
 
 #if defined(__loongarch_asx)
-
-typedef union {
-    int32_t i;
-    float f;
-} ft_union;
-
 /* float type data load instructions */
-static __m128 __lsx_vreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+static __m128 __lsx_vreplfr2vr_s(const float val) {
+    v4f32 res = {val, val, val, val};
+    return (__m128)res;
 }
 
-static __m256 __lasx_xvreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+static __m256 __lasx_xvreplfr2vr_s(const float val) {
+    v8f32 res = {val, val, val, val, val, val, val, val};
+    return (__m256)res;
 }
 #endif
 
index 72ec58ceef6d6ed01d4562b2cdada53c82c12df0..27ec149356543110fcdda6c3af3db623ece750a8 100644 (file)
@@ -501,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
 }
 
 static __m256i lasx_extu8_16(__m128i a) {
-    __m128i zero = __lsx_vldi(0);
-    __m128i vlo = __lsx_vilvl_b(zero, a);
-    __m128i vhi = __lsx_vilvh_b(zero, a);
-    return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_hu_bu(____m256i(a));
 }
 
 static __m256i lasx_ext8_16(__m128i a) {
-     __m128i sign = __lsx_vslti_b(a, 0);
-     __m128i vlo = __lsx_vilvl_b(sign, a);
-     __m128i vhi = __lsx_vilvh_b(sign, a);
-     return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_h_b(____m256i(a));
 }
 
 static __m256i lasx_ext16_32(__m128i a) {
-    __m256i tmp1;
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
-    return tmp1;
+    return __lasx_vext2xv_w_h(____m256i(a));
 }
 
 static __m128i lasx_extracti128( __m256i a, int pos) {
@@ -592,12 +577,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = lasx_extractf128(x, 1);
-    ft_union tmp;
     res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
-    tmp.i = __lsx_vpickve2gr_w(res, 0);
-    return tmp.f;
+    return ((v4f32)res)[0];
 }
 
 // horizontally add 8 int32_t
@@ -939,7 +922,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union fi;
         __m256 v0 = (__m256)__lasx_xvld( x , 0);
         __m256 v1 = (__m256)__lasx_xvld( x , 32);
         __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -957,8 +939,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
-        fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = fi.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -1263,7 +1244,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union ft;
         __m256 v0 = (__m256)__lasx_xvld( x , 0 );
         __m256 v1 = (__m256)__lasx_xvld( x , 32 );
         __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1281,8 +1261,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
-        ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = ft.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -6154,9 +6133,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
 
 
-    ft_union fi;
-    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
-    *s = hsum_float_8(acc) + fi.f ;
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
index e809f05d217a91aa29e6118fe610ee3d0319ae77..59efaeb71292e52e177609cc2da3db7fd4a45c76 100644 (file)
@@ -1078,29 +1078,23 @@ do {                                                              \
 #define GGML_F16_STEP 32
 #define GGML_F16_EPR  8
 
-// F16 arithmetic is not supported by AVX, so we use F32 instead
+// F16 arithmetic is not supported by LASX, so we use F32 instead
 
 #define GGML_F32Cx8          __m256
 #define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
 #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
 
 static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return (__m256)__lasx_xvld(tmp, 0);
+    __m256i a;
+    memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
+    a = __lasx_xvpermi_d(a, 0 | (1 << 4));
+    return __lasx_xvfcvtl_s_h(a);
 }
-static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
-    float arr[8];
 
-    __lasx_xvst(y, arr, 0);
-
-    for (int i = 0; i < 8; i++) {
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-    }
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+    __m256i a = __lasx_xvfcvt_h_s(y, y);
+    a = __lasx_xvpermi_d(a, 0 | (2 << 2));
+    memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
 }
 #define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
 #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)