}
static __m256i lasx_extu8_16(__m128i a) {
- __m128i zero = __lsx_vldi(0);
- __m128i vlo = __lsx_vilvl_b(zero, a);
- __m128i vhi = __lsx_vilvh_b(zero, a);
- return lasx_set_q(vhi, vlo);
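+ // ____m256i() reinterprets a as the low 128 bits of an __m256i; xvext2xv.hu.bu zero-extends those 16 bytes to 16-bit lanes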
+ return __lasx_vext2xv_hu_bu(____m256i(a));
}
static __m256i lasx_ext8_16(__m128i a) {
- __m128i sign = __lsx_vslti_b(a, 0);
- __m128i vlo = __lsx_vilvl_b(sign, a);
- __m128i vhi = __lsx_vilvh_b(sign, a);
- return lasx_set_q(vhi, vlo);
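+ // xvext2xv.h.b does the same widening with sign extension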
+ return __lasx_vext2xv_h_b(____m256i(a));
}
static __m256i lasx_ext16_32(__m128i a) {
- __m256i tmp1;
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
- tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
- return tmp1;
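+ // xvext2xv.w.h sign-extends the 8 halfwords to 32-bit lanes, replacing the eight scalar insert/extract pairs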
+ return __lasx_vext2xv_w_h(____m256i(a));
}
static __m128i lasx_extracti128( __m256i a, int pos) {
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = lasx_extractf128(x, 1);
- ft_union tmp;
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
- tmp.i = __lsx_vpickve2gr_w(res, 0);
- return tmp.f;
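+ // read lane 0 of the vector directly instead of bouncing through ft_union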
+ return ((v4f32)res)[0];
}
// horizontally add 8 int32_t
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
- ft_union fi;
__m256 v0 = (__m256)__lasx_xvld( x , 0);
__m256 v1 = (__m256)__lasx_xvld( x , 32);
__m256 v2 = (__m256)__lasx_xvld( x , 64);
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
- fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
- const float max_scalar = fi.f;
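+ // v4f32 element access reads the horizontal max straight from lane 0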
+ const float max_scalar = ((v4f32)max4)[0];
// Quantize these floats
const float d = max_scalar / 127.f;
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
- ft_union ft;
__m256 v0 = (__m256)__lasx_xvld( x , 0 );
__m256 v1 = (__m256)__lasx_xvld( x , 32 );
__m256 v2 = (__m256)__lasx_xvld( x , 64 );
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
- ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
- const float max_scalar = ft.f;
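+ // lane 0 again holds the horizontal max after the folding above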
+ const float max_scalar = ((v4f32)max4)[0];
// Quantize these floats
const float d = max_scalar / 127.f;
acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
- ft_union fi;
- fi.i = __lsx_vpickve2gr_w(acc_m, 0);
- *s = hsum_float_8(acc) + fi.f ;
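+ // fold in lane 0 of acc_m directly; no union needed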
+ *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
#else
const uint8_t * scales = (const uint8_t*)&utmp[0];
#define GGML_F16_STEP 32
#define GGML_F16_EPR 8
-// F16 arithmetic is not supported by AVX, so we use F32 instead
+// F16 arithmetic is not supported by LASX, so we use F32 instead
#define GGML_F32Cx8 __m256
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
- float tmp[8];
-
- for (int i = 0; i < 8; i++) {
- tmp[i] = GGML_FP16_TO_FP32(x[i]);
- }
-
- return (__m256)__lasx_xvld(tmp, 0);
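+ // copy the 8 f16 values into the low 128 bits, spread the two 64-bit
+ // halves across the register's 128-bit lanes, then widen with xvfcvtl.s.h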
+ __m256i a;
+ memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
+ a = __lasx_xvpermi_d(a, 0 | (1 << 4));
+ return __lasx_xvfcvtl_s_h(a);
}
static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
- float arr[8];
- __lasx_xvst(y, arr, 0);
-
- for (int i = 0; i < 8; i++) {
- x[i] = GGML_FP32_TO_FP16(arr[i]);
- }
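+ // narrow both lanes to f16, gather the two valid 64-bit groups into the
+ // low 128 bits with xvpermi.d, then store the 8 halfwords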
+ __m256i a = __lasx_xvfcvt_h_s(y, y);
+ a = __lasx_xvpermi_d(a, 0 | (2 << 2));
+ memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
}
#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)