]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml: aarch64: SVE kernels for q8_0_q8_0, q4_0_q8_0 vector dot (llama/7433)
authorMasaya, Kato <redacted>
Sat, 25 May 2024 08:42:31 +0000 (17:42 +0900)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
* Add SVE support for q4_0_q8_0 q8_0_q8_0

* remove ifdef

include/ggml/ggml.h
src/ggml-impl.h
src/ggml-quants.c
src/ggml.c

index be81e0c52316bed34c719cf5bdf108b3b06947c0..f803ba7241fe1b457f8ea10e93e4f72d9544288f 100644 (file)
@@ -2404,6 +2404,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512_bf16(void);
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_sve        (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
     GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);
index 362d40f4d1d8bb43f37944c4b88149af23f6662f..5e77471f332f443277c835f25fc916dd16fd26ca 100644 (file)
@@ -144,6 +144,10 @@ extern "C" {
 #endif
 #endif
 
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#endif
+
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
index bb01ce93cb9693aa81077079f5a905f4071bf841..4f2c7224c3e753ef51eb70b0c7473d99966d4ba0 100644 (file)
@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+        // 4-bit -> 8-bit
+        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+        // sub 8
+        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        // dot product
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
index 9e72b7a765dbae38e6765f31051404ab103d8958..5145ceec9f4b2a41d9f6deeb37d1532013461432 100644 (file)
@@ -22742,6 +22742,16 @@ int ggml_cpu_has_neon(void) {
 #endif
 }
 
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+    // TODO: Currently, SVE 256 bit is only supported.
+    GGML_ASSERT(svcntb() == QK8_0);
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_arm_fma(void) {
 #if defined(__ARM_FEATURE_FMA)
     return 1;