]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix loongarch build (O2 issue) (llama/7636)
authorjunchao-loongson <redacted>
Thu, 30 May 2024 09:30:10 +0000 (17:30 +0800)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
src/ggml-quants.c
src/ggml.c

index 4f2c7224c3e753ef51eb70b0c7473d99966d4ba0..1128d66e24c363de86396e0e6fcc36b9a07594cb 100644 (file)
@@ -6828,6 +6828,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
         int bit = 0;
         int is  = 0;
+        __m256i xvbit;
 
         const uint8_t * restrict q3 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -6836,21 +6837,25 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             // load low 2 bits
             const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             // prepare low and high bits
             const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
-            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
-            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
-            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit);
             const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
-            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvandn_v(hbits, __lasx_xvslli_h(mone, bit)), bit), 2);
+            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
             ++bit;
 
             // load Q8 quants
@@ -8033,6 +8038,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         __m256i sumi = __lasx_xvldi(0);
 
         int bit = 0;
+        __m256i xvbit;
 
         for (int j = 0; j < QK_K/64; ++j) {
 
@@ -8041,13 +8047,15 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
             const __m256i q5_0  = __lasx_xvadd_b(q5l_0, q5h_0);
             hmask = __lasx_xvslli_h(hmask, 1);
 
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
             const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrli_h(__lasx_xvand_v(hbits, hmask), bit++), 4);
+            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
             const __m256i q5_1  = __lasx_xvadd_b(q5l_1, q5h_1);
             hmask = __lasx_xvslli_h(hmask, 1);
 
index b2b725f65452c8be951817efdd3b7aa167b50015..f3a90ff2c06329ec26c04d09f327077f05051cb2 100644 (file)
@@ -1580,7 +1580,7 @@ do {                                                              \
 #define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
 #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
 
-static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
+static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t *x) {
     float tmp[8];
 
     for (int i = 0; i < 8; i++) {