]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: aarch64: SVE kernels for q8_0_q8_0, q4_0_q8_0 vector dot (#7433)
authorMasaya, Kato <redacted>
Sat, 25 May 2024 08:42:31 +0000 (17:42 +0900)
committerGitHub <redacted>
Sat, 25 May 2024 08:42:31 +0000 (11:42 +0300)
* Add SVE support for q4_0_q8_0 q8_0_q8_0

* remove ifdef

CMakeLists.txt
common/common.cpp
ggml-impl.h
ggml-quants.c
ggml.c
ggml.h
llama.cpp

index ef02ff66967f38c03589e9530ca35e808680cc32..c5add8239c2bd3bdc92a2c42e6d870e7a712bf0e 100644 (file)
@@ -72,6 +72,7 @@ else()
     set(INS_ENB ON)
 endif()
 
+option(LLAMA_SVE                             "llama: enable SVE"                                OFF)
 option(LLAMA_AVX                             "llama: enable AVX"                                ${INS_ENB})
 option(LLAMA_AVX2                            "llama: enable AVX2"                               ${INS_ENB})
 option(LLAMA_AVX512                          "llama: enable AVX512"                             OFF)
@@ -1040,6 +1041,9 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STR
             # Raspberry Pi 3, 4, Zero 2 (32-bit)
             list(APPEND ARCH_FLAGS -mno-unaligned-access)
         endif()
+        if (LLAMA_SVE)
+            list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
+        endif()
     endif()
 elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
         (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
index 401d72bac00ce6ec02b43ad9f168179df25f8925..c6459038560f137a56c5413ac9893b0dbab29faa 100644 (file)
@@ -2844,6 +2844,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "cpu_has_fma: %s\n",         ggml_cpu_has_fma()         ? "true" : "false");
     fprintf(stream, "cpu_has_gpublas: %s\n",     ggml_cpu_has_gpublas()     ? "true" : "false");
     fprintf(stream, "cpu_has_neon: %s\n",        ggml_cpu_has_neon()        ? "true" : "false");
+    fprintf(stream, "cpu_has_sve: %s\n",         ggml_cpu_has_sve()         ? "true" : "false");
     fprintf(stream, "cpu_has_f16c: %s\n",        ggml_cpu_has_f16c()        ? "true" : "false");
     fprintf(stream, "cpu_has_fp16_va: %s\n",     ggml_cpu_has_fp16_va()     ? "true" : "false");
     fprintf(stream, "cpu_has_wasm_simd: %s\n",   ggml_cpu_has_wasm_simd()   ? "true" : "false");
index 362d40f4d1d8bb43f37944c4b88149af23f6662f..5e77471f332f443277c835f25fc916dd16fd26ca 100644 (file)
@@ -144,6 +144,10 @@ extern "C" {
 #endif
 #endif
 
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#endif
+
 // 16-bit float
 // on Arm, we use __fp16
 // on x86, we use uint16_t
index bb01ce93cb9693aa81077079f5a905f4071bf841..4f2c7224c3e753ef51eb70b0c7473d99966d4ba0 100644 (file)
@@ -3813,7 +3813,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
+    const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
+
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+        const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+        // 4-bit -> 8-bit
+        const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
+        const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
+
+        // sub 8
+        const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+        const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        // dot product
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
@@ -5384,7 +5421,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         return;
     }
 #endif
-#if defined(__ARM_NEON)
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    assert(nb % 2 == 0); // TODO: handle odd nb
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q8_0 * restrict x0 = &x[i + 0];
+        const block_q8_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        // load x
+        const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+        const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+        // load y
+        const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+        const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+        sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+#elif defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
diff --git a/ggml.c b/ggml.c
index 9e72b7a765dbae38e6765f31051404ab103d8958..5145ceec9f4b2a41d9f6deeb37d1532013461432 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -22742,6 +22742,16 @@ int ggml_cpu_has_neon(void) {
 #endif
 }
 
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_FEATURE_SVE)
+    // TODO: Currently, SVE 256 bit is only supported.
+    GGML_ASSERT(svcntb() == QK8_0);
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_arm_fma(void) {
 #if defined(__ARM_FEATURE_FMA)
     return 1;
diff --git a/ggml.h b/ggml.h
index be81e0c52316bed34c719cf5bdf108b3b06947c0..f803ba7241fe1b457f8ea10e93e4f72d9544288f 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -2404,6 +2404,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512_bf16(void);
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_sve        (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
     GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);
index 3c9fe15bb459688bc222a5d91a047000e1d4fbbd..85cb3140d945b5bf26d68589548d943be0e49c77 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -18337,6 +18337,7 @@ const char * llama_print_system_info(void) {
     s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
     s += "FMA = "         + std::to_string(ggml_cpu_has_fma())         + " | ";
     s += "NEON = "        + std::to_string(ggml_cpu_has_neon())        + " | ";
+    s += "SVE = "         + std::to_string(ggml_cpu_has_sve())         + " | ";
     s += "ARM_FMA = "     + std::to_string(ggml_cpu_has_arm_fma())     + " | ";
     s += "F16C = "        + std::to_string(ggml_cpu_has_f16c())        + " | ";
     s += "FP16_VA = "     + std::to_string(ggml_cpu_has_fp16_va())     + " | ";