]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-cpu : optimize RVV kernels (llama/15720)
authorxctan <redacted>
Wed, 3 Sep 2025 08:16:21 +0000 (16:16 +0800)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:12 +0000 (12:54 +0300)
* ggml-cpu : optimize rvv ggml_vec_dot_f32

* ggml-cpu : optimize 128-bit rvv ggml_vec_dot_q4_K_q8_K

* ggml-cpu : fix riscv arch flags

* ggml-cpu : add more rvv ops

* ggml-cpu : optimize rvv ggml_vec_dot_q4_K_q8_K

* ggml-cpu : optimize rvv ggml_vec_dot_q6_K_q8_K

* ggml-cpu : minor rvv adjustments

* ggml-cpu : fix riscv include

CMakeLists.txt
src/ggml-cpu/CMakeLists.txt
src/ggml-cpu/arch/riscv/quants.c
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/vec.cpp
src/ggml-cpu/vec.h

index 96be001f8cb7f86ed852e599a2f94557ca9719a0..9ef88c6fd0a85945c3ec4e8f101f4e31655b6ae9 100644 (file)
@@ -129,7 +129,9 @@ endif()
 option(GGML_LASX             "ggml: enable lasx"             ON)
 option(GGML_LSX              "ggml: enable lsx"              ON)
 option(GGML_RVV              "ggml: enable rvv"              ON)
-option(GGML_RV_ZFH           "ggml: enable riscv zfh"        OFF)
+option(GGML_RV_ZFH           "ggml: enable riscv zfh"        ON)
+option(GGML_RV_ZVFH          "ggml: enable riscv zvfh"       ON)
+option(GGML_RV_ZICBOP        "ggml: enable riscv zicbop"     ON)
 option(GGML_XTHEADVECTOR     "ggml: enable xtheadvector"     OFF)
 option(GGML_VXE              "ggml: enable vxe"              ON)
 option(GGML_NNPA             "ggml: enable nnpa"             OFF)  # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877
index 040b7ded90588c795e5fbad144b18df2c35732e3..dd8c1cf67840ee43130f14100805462b8ee94b81 100644 (file)
@@ -433,15 +433,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ggml-cpu/arch/riscv/quants.c
             ggml-cpu/arch/riscv/repack.cpp
             )
-        if (GGML_RVV)
-            if (GGML_XTHEADVECTOR)
-                list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d)
-            elseif (GGML_RV_ZFH)
-                list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
-            else()
-                list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
+        set(MARCH_STR "rv64gc")
+        if (GGML_RV_ZFH)
+            string(APPEND MARCH_STR "_zfh")
+        endif()
+        if (GGML_XTHEADVECTOR)
+            string(APPEND MARCH_STR "_xtheadvector")
+        elseif (GGML_RVV)
+            string(APPEND MARCH_STR "_v")
+            if (GGML_RV_ZVFH)
+                string(APPEND MARCH_STR "_zvfh")
             endif()
         endif()
+        if (GGML_RV_ZICBOP)
+            string(APPEND MARCH_STR "_zicbop")
+        endif()
+        list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
     elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
         message(STATUS "s390x detected")
         list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
index 6c74417c90c1f5dbdfab2e42ed5eac69aa995ac6..ee41a3502e82d1f4d33a8cdf9473d29e1034fad4 100644 (file)
@@ -1270,29 +1270,40 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
             const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
             const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
 
-            int tmp, tmp2, sumi;
+            float ftmp, ft2;
+            const uint8_t * restrict q40;
+            const uint8_t * restrict q41;
+            const uint8_t * restrict q42;
+            const uint8_t * restrict q43;
+            const int8_t  * restrict q80;
+            const int8_t  * restrict q81;
+            const int8_t  * restrict q82;
+            const int8_t  * restrict q83;
+            int s0, s1, s2, s3;
+
             __asm__ __volatile__(
-                "vsetivli zero, 12, e8, m1\n\t"
-                "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
-                "vsetivli zero, 4, e32, m1\n\t"
+                "li %[s1], 8\n\t"
+                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+                "vle32.v v1, (%[s6b])\n\t"
+                "vslide1down.vx v1, v1, zero\n\t"
+                "vmv.v.x v16, zero\n\t"
                 "vslidedown.vi v2, v1, 2\n\t"
                 "vmv1r.v v3, v2\n\t"
                 "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
-                "vsetivli zero, 2, e32, m1\n\t"
+                "vsetivli zero, 2, e32, m1, ta, ma\n\t"
                 "vmv.v.i v4, 4\n\t"
                 "vand.vx v8, v1, %[kmask1]\n\t"
                 "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
                 "vsrl.vi v6, v1, 6\n\t"
                 "vsrl.vv v7, v2, v5\n\t"
+                "vsse32.v v8, (%[utmp]), %[s1]\n\t"
                 "vand.vx v0, v6, %[kmask3]\n\t"
                 "vand.vx v2, v7, %[kmask2]\n\t"
                 "vsll.vi v6, v0, 4\n\t"
-                "li %[t2], 8\n\t"
-                "addi %[t1], %[utmp], 4\n\t"
+                "addi %[s0], %[utmp], 4\n\t"
                 "vor.vv v1, v6, v2\n\t"
-                "vsse32.v v8, (%[utmp]), %[t2]\n\t"
-                "vsse32.v v1, (%[t1]), %[t2]\n\t"
-                "vsetivli zero, 8, e16, m1\n\t"
+                "vsse32.v v1, (%[s0]), %[s1]\n\t"
+                "vsetivli zero, 8, e16, m1, ta, ma\n\t"
                 "vle32.v v2, (%[bsums])\n\t"
                 "vnsrl.wi v0, v2, 0\n\t"
                 "vnsrl.wi v1, v2, 16\n\t"
@@ -1300,13 +1311,131 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                 "vle8.v v3, (%[mins])\n\t"
                 "vzext.vf2 v4, v3\n\t"
                 "vwmul.vv v6, v4, v2\n\t"
+                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+                "vredsum.vs v0, v6, v16\n\t"
+                "vredsum.vs v0, v7, v0\n\t"
+                "vfcvt.f.x.v v0, v0\n\t"
+                "vfmv.f.s %[ftmp], v0\n\t"
+                "vsetivli zero, 16, e8, m1, ta, ma\n\t"
+                "vle8.v v0, (%[xs])\n\t"
+                "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
+                "addi %[q40], %[xs], 64\n\t"
+                "addi %[q41], %[xs], 16\n\t"
+                "addi %[q42], %[xs], 32\n\t"
+                "addi %[q43], %[xs], 48\n\t"
+                "addi %[q80], %[ys], 64\n\t"
+                "vle8.v v1, (%[q41])\n\t"
+                "vle8.v v2, (%[q42])\n\t"
+                "addi %[q81], %[ys], 16\n\t"
+                "addi %[q41], %[q41], 64\n\t"
+                "addi %[q82], %[ys], 32\n\t"
+                "vle8.v v3, (%[q43])\n\t"
+                "vle8.v v8, (%[ys])\n\t"
+                "addi %[q42], %[q42], 64\n\t"
+                "addi %[q83], %[ys], 48\n\t"
+                "addi %[q43], %[q43], 64\n\t"
+                "vsrl.vi v4, v0, 4\n\t"
+                "vle8.v v9, (%[q81])\n\t"
+                "vle8.v v10, (%[q82])\n\t"
+                "vand.vi v0, v0, 0xF\n\t"
+                "addi %[q81], %[q81], 64\n\t"
+                "vsrl.vi v5, v1, 4\n\t"
+                "addi %[q82], %[q82], 64\n\t"
+                "vle8.v v11, (%[q83])\n\t"
+                "vle8.v v12, (%[q80])\n\t"
+                "vand.vi v1, v1, 0xF\n\t"
+                "addi %[q83], %[q83], 64\n\t"
+                "vsrl.vi v6, v2, 4\n\t"
+                "addi %[q80], %[q80], 64\n\t"
+                "vle8.v v13, (%[q81])\n\t"
+                "vle8.v v14, (%[q82])\n\t"
+                "vand.vi v2, v2, 0xF\n\t"
+                "addi %[q81], %[q81], 64\n\t"
+                "vsrl.vi v7, v3, 4\n\t"
+                "addi %[q82], %[q82], 64\n\t"
+                "vwmul.vv v16, v0, v8\n\t"
+                "vle8.v v15, (%[q83])\n\t"
+                "vle8.v v0, (%[q40])\n\t"
+                "vand.vi v3, v3, 0xF\n\t"
+                "addi %[q83], %[q83], 64\n\t"
+                "vwmul.vv v24, v2, v12\n\t"
+                "vwmul.vv v20, v4, v10\n\t"
+                "vwmul.vv v28, v6, v14\n\t"
+                "vwmacc.vv v16, v1, v9\n\t"
+                "vle8.v v1, (%[q41])\n\t"
+                "vle8.v v2, (%[q42])\n\t"
+                "vwmacc.vv v24, v3, v13\n\t"
+                "vwmacc.vv v20, v5, v11\n\t"
+                "vwmacc.vv v28, v7, v15\n\t"
+                "addi %[q40], %[q80], 64\n\t"
+                "addi %[q41], %[q81], 64\n\t"
+                "vle8.v v3, (%[q43])\n\t"
+                "vle8.v v8, (%[q80])\n\t"
+                "addi %[q42], %[q82], 64\n\t"
+                "addi %[q43], %[q83], 64\n\t"
+                "vsrl.vi v4, v0, 4\n\t"
+                "vle8.v v9, (%[q81])\n\t"
+                "vle8.v v10, (%[q82])\n\t"
+                "vand.vi v0, v0, 0xF\n\t"
+                "vsrl.vi v5, v1, 4\n\t"
+                "vsrl.vi v7, v3, 4\n\t"
+                "vand.vi v3, v3, 0xF\n\t"
+                "vle8.v v11, (%[q83])\n\t"
+                "vle8.v v12, (%[q40])\n\t"
+                "vand.vi v1, v1, 0xF\n\t"
+                "vsrl.vi v6, v2, 4\n\t"
+                "vand.vi v2, v2, 0xF\n\t"
+                "vwmul.vv v18, v0, v8\n\t"
+                "vle8.v v13, (%[q41])\n\t"
+                "vle8.v v14, (%[q42])\n\t"
+                "vwmul.vv v26, v2, v12\n\t"
+                "vwmul.vv v22, v4, v10\n\t"
+                "vwmul.vv v30, v6, v14\n\t"
+                "vwmacc.vv v18, v1, v9\n\t"
+                "vle8.v v15, (%[q43])\n\t"
+                "vwmacc.vv v26, v3, v13\n\t"
+                "vwmacc.vv v22, v5, v11\n\t"
+                "vwmacc.vv v30, v7, v15\n\t"
                 "vmv.v.x v0, zero\n\t"
-                "vsetivli zero, 8, e32, m2\n\t"
-                "vredsum.vs v0, v6, v0\n\t"
-                "vmv.x.s %[sumi], v0"
-                : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
-                : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
-                , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
+                "vsetivli zero, 16, e16, m2, ta, ma\n\t"
+                "vwredsum.vs v4, v16, v0\n\t"
+                "lbu %[s0], 0(%[scale])\n\t"
+                "vwredsum.vs v5, v20, v0\n\t"
+                "lbu %[s1], 1(%[scale])\n\t"
+                "vwredsum.vs v6, v24, v0\n\t"
+                "lbu %[s2], 2(%[scale])\n\t"
+                "vwredsum.vs v7, v28, v0\n\t"
+                "lbu %[s3], 3(%[scale])\n\t"
+                "vwredsum.vs v8, v18, v0\n\t"
+                "lbu %[q40], 4(%[scale])\n\t"
+                "vwredsum.vs v9, v22, v0\n\t"
+                "lbu %[q41], 5(%[scale])\n\t"
+                "vwredsum.vs v10, v26, v0\n\t"
+                "lbu %[q42], 6(%[scale])\n\t"
+                "vwredsum.vs v11, v30, v0\n\t"
+                "lbu %[q43], 7(%[scale])\n\t"
+                "vsetivli zero, 4, e32, m1, ta, ma\n\t"
+                "vmul.vx v0, v4, %[s0]\n\t"
+                "vmul.vx v1, v8, %[q40]\n\t"
+                "vmacc.vx v0, %[s1], v5\n\t"
+                "vmacc.vx v1, %[q41], v9\n\t"
+                "vmacc.vx v0, %[s2], v6\n\t"
+                "vmacc.vx v1, %[q42], v10\n\t"
+                "vmacc.vx v0, %[s3], v7\n\t"
+                "vmacc.vx v1, %[q43], v11\n\t"
+                "vfcvt.f.x.v v0, v0\n\t"
+                "vfcvt.f.x.v v1, v1\n\t"
+                "vfmv.f.s %[ft2], v0\n\t"
+                "vfmv.f.s %[ftmp], v1\n\t"
+                "fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
+                "fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
+                : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
+                , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
+                , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
+                , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
+                : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
+                , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
+                , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
                 , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
                 : "memory"
                 , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
@@ -1314,59 +1443,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                 , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                 , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
             );
-            sumf -= dmin * sumi;
-
-            const uint8_t * restrict q4 = x[i].qs;
-            const int8_t  * restrict q8 = y[i].qs;
-
-            sumi = 0;
-            const uint8_t * scale = scales;
-
-            for (int j = 0; j < QK_K/128; ++j) {
-                int vl128 = 128, vl64 = 64, vl32 = 32;
-                __asm__ __volatile__(
-                    "vsetvli zero, %[vl128], e8, m8\n\t"
-                    "vle8.v v8, (%[q8])\n\t"
-                    "vsetvli zero, %[vl64], e8, m4\n\t"
-                    "vle8.v v0, (%[q4])\n\t"
-                    "vsrl.vi v4, v0, 4\n\t"
-                    "vand.vi v0, v0, 0xF\n\t"
-                    "vsetvli zero, %[vl32], e8, m2\n\t"
-                    "vwmul.vv v28, v6, v14\n\t"
-                    "vwmul.vv v20, v4, v10\n\t"
-                    "vwmul.vv v24, v2, v12\n\t"
-                    "vwmul.vv v16, v0, v8\n\t"
-                    "vsetivli zero, 4, e32, m1\n\t"
-                    "vle8.v v2, (%[scale])\n\t"
-                    "vmv.v.x v0, zero\n\t"
-                    "vzext.vf4 v1, v2\n\t"
-                    "vsetvli zero, %[vl32], e16, m4\n\t"
-                    "vwredsum.vs v6, v24, v0\n\t"
-                    "vwredsum.vs v7, v28, v0\n\t"
-                    "vwredsum.vs v4, v16, v0\n\t"
-                    "vwredsum.vs v5, v20, v0\n\t"
-                    "vsetivli zero, 4, e32, m1\n\t"
-                    "vslideup.vi v6, v7, 1\n\t"
-                    "vslideup.vi v4, v5, 1\n\t"
-                    "vslideup.vi v4, v6, 2\n\t"
-                    "vmul.vv v8, v4, v1\n\t"
-                    "vredsum.vs v0, v8, v0\n\t"
-                    "vmv.x.s %[tmp], v0\n\t"
-                    "add %[sumi], %[sumi], %[tmp]"
-                    : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
-                    : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
-                    , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
-                    : "memory"
-                    , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
-                    , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
-                    , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
-                    , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
-                );
-
-                q4 += 64;    q8 += 128;    scale += 4;
-            }
-
-            sumf += d * sumi;
         }
         break;
     default:
@@ -1693,6 +1769,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     case 128:
         for (int i = 0; i < nb; ++i) {
 
+            __builtin_prefetch(&x[i + 1].d, 0, 1);
+
             const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
 
             const uint8_t * restrict q6 = x[i].ql;
@@ -1701,23 +1779,59 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
             const int8_t * restrict scale = x[i].scales;
 
-            int sum_t = 0;
-            int t0;
+            int q6h;
+            float ftmp;
 
             for (int j = 0; j < QK_K/128; ++j) {
                 __asm__ __volatile__(
+                    "addi %[q6h], %[q6], 32\n\t"
+                    "ld t0, 0(%[scale])\n\t"
+                    "addi %[scale], %[scale], 8\n\t"
+                    "slli t6, t0, 1 * 8\n\t"
+                    "lb zero, 0(%[q6])\n\t"
+                    "slli t5, t0, 2 * 8\n\t"
+                    "slli t4, t0, 3 * 8\n\t"
+                    "lb zero, 0(%[q6h])\n\t"
+                    "slli t3, t0, 4 * 8\n\t"
+                    "slli t2, t0, 5 * 8\n\t"
+                    "lb zero, 0(%[qh])\n\t"
+                    "lb zero, 31(%[q6h])\n\t"
+                    "slli t1, t0, 6 * 8\n\t"
+                    "srai a7, t0, 56\n\t"
                     "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "vle8.v v8, (%[q6])\n\t"
+                    "srai t6, t6, 56\n\t"
+                    "srai t5, t5, 56\n\t"
+                    "srai t4, t4, 56\n\t"
+                    "srai t3, t3, 56\n\t"
+                    "vle8.v v10, (%[q6h])\n\t"
+                    "addi %[q6], %[q6], 64\n\t"
+                    "slli t0, t0, 7 * 8\n\t"
+                    "srai t2, t2, 56\n\t"
+                    "srai t1, t1, 56\n\t"
+                    "srai t0, t0, 56\n\t"
                     "vle8.v v4, (%[qh])\n\t"
+                    "vsrl.vi v12, v8, 4\n\t"
+                    "vsrl.vi v14, v10, 4\n\t"
+                    "lb zero, 0(%[q8])\n\t"
+                    "vand.vi v8, v8, 0xF\n\t"
+                    "vand.vi v10, v10, 0xF\n\t"
+                    "lb zero, 32(%[q8])\n\t"
                     "vsll.vi v0, v4, 4\n\t"
                     "vsll.vi v2, v4, 2\n\t"
+                    "lb zero, 64(%[q8])\n\t"
                     "vsrl.vi v6, v4, 2\n\t"
-                    "vsetvli zero, %[vl64], e8, m4\n\t"
-                    "vle8.v v8, (%[q6])\n\t"
-                    "vsrl.vi v12, v8, 4\n\t"
-                    "vand.vi v8, v8, 0xF\n\t"
-                    "vsetvli zero, %[vl128], e8, m8\n\t"
                     "vand.vx v0, v0, %[mask]\n\t"
+                    "lb zero, 96(%[q8])\n\t"
+                    "vand.vx v2, v2, %[mask]\n\t"
+                    "vand.vx v4, v4, %[mask]\n\t"
+                    "vand.vx v6, v6, %[mask]\n\t"
                     "vor.vv v8, v8, v0\n\t"
+                    "lb zero, 127(%[q8])\n\t"
+                    "vor.vv v10, v10, v2\n\t"
+                    "vor.vv v12, v12, v4\n\t"
+                    "vor.vv v14, v14, v6\n\t"
+                    "vsetvli zero, %[vl128], e8, m8\n\t"
                     "vle8.v v0, (%[q8])\n\t"
                     "vsub.vx v8, v8, %[vl32]\n\t"
                     "vsetvli zero, %[vl64], e8, m4\n\t"
@@ -1734,34 +1848,34 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                     "vwredsum.vs v13, v28, v0\n\t"
                     "vwredsum.vs v14, v30, v0\n\t"
                     "vsetivli zero, 4, e32, m1\n\t"
-                    "vslideup.vi v10, v9, 1\n\t"
-                    "vslideup.vi v8, v7, 1\n\t"
-                    "vslideup.vi v11, v12, 1\n\t"
-                    "vslideup.vi v13, v14, 1\n\t"
-                    "vslideup.vi v10, v8, 2\n\t"
-                    "vslideup.vi v11, v13, 2\n\t"
-                    "vsetivli zero, 8, e32, m2\n\t"
-                    "vle8.v v2, (%[scale])\n\t"
-                    "vsext.vf4 v4, v2\n\t"
-                    "vmul.vv v2, v4, v10\n\t"
-                    "vredsum.vs v0, v2, v0\n\t"
-                    "vmv.x.s %[t0], v0\n\t"
-                    "add %[sumi], %[sumi], %[t0]"
-                    : [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
-                    : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
+                    "vmul.vx v0, v10, t0\n\t"
+                    "vmul.vx v1, v9, t1\n\t"
+                    "vmacc.vx v0, t2, v8\n\t"
+                    "vmacc.vx v1, t3, v7\n\t"
+                    "vmacc.vx v0, t4, v11\n\t"
+                    "vmacc.vx v1, t5, v12\n\t"
+                    "vmacc.vx v0, t6, v13\n\t"
+                    "vmacc.vx v1, a7, v14\n\t"
+                    "vadd.vv v0, v0, v1\n\t"
+                    "vfcvt.f.x.v v0, v0\n\t"
+                    "vfmv.f.s %[ftmp], v0\n\t"
+                    "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
+                    : [q6] "+&r" (q6), [q6h] "=&r" (q6h)
+                    , [scale] "+&r" (scale)
+                    , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
+                    : [qh] "r" (qh), [q8] "r" (q8)
                     , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
-                    , [mask] "r" (0x30)
+                    , [mask] "r" (0x30), [d] "f" (d)
                     : "memory"
                     , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                     , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
                     , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
                     , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
+                    , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
+                    , "a6", "a5", "a4", "a3"
                 );
-                q6 += 64;   qh += 32;   q8 += 128;   scale += 8;
+                qh += 32;   q8 += 128;
             }
-
-            sumf += d * sum_t;
-
         }
         break;
     default:
index 0d5d3a3440aaf77336c5430160c990a5216fcc18..78ec189d4c64671a62351339070fd4a60f5ec1c3 100644 (file)
@@ -3221,6 +3221,13 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
         uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0);
         vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
     }
+#elif defined(__riscv_zvfh)
+    for (int vl; i < n; i += vl) {
+        vl = __riscv_vsetvl_e32m2(n - i);
+        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
+        vfloat16m1_t vy = __riscv_vfncvt_f_f_w_f16m1(vx, vl);
+        __riscv_vse16_v_f16m1((_Float16 *)&y[i], vy, vl);
+    }
 #endif
     for (; i < n; ++i) {
         y[i] = GGML_CPU_FP32_TO_FP16(x[i]);
index 0652155cf7b0ea81b4cdb063764190fe948ef61a..437192d525a34fd072ec058cc94e7085cfeaecd8 100644 (file)
@@ -85,15 +85,21 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
         // reduce sum1,sum2 to sum1
         GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
     #elif defined(__riscv_v_intrinsic)
-        vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);
-        for (int i = 0, avl; i < n; i += avl) {
-            avl = __riscv_vsetvl_e32m8(n - i);
-            vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
-            vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
-            vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
-            vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);
+        int vl = __riscv_vsetvlmax_e32m8();
+        vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
+        vfloat32m8_t vsum;
+        vfloat32m8_t ax;
+        vfloat32m8_t ay;
+        vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);
+        for (int i = 0; i < n; i += vl) {
+            vl = __riscv_vsetvl_e32m8(n - i);
+            ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);
+            ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);
+            vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);
         }
-        sumf += __riscv_vfmv_f_s_f32m1_f32(vsum);
+        vl = __riscv_vsetvlmax_e32m8();
+        vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);
+        sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
     #else
         const int np = (n & ~(GGML_F32_STEP - 1));
 
@@ -208,7 +214,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
     ggml_float sumf = 0.0;
 
 
-#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
+#if defined(GGML_SIMD)
     #if defined(__ARM_FEATURE_SVE)
         const int sve_register_length = svcntb() * 8; //get vector length
         const int ggml_f16_epr = sve_register_length / 16; // running when 16
@@ -271,6 +277,29 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
             sum1 = svmad_f16_x(pg, hx, hy, sum1);
         }
         GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
+    #elif defined(__riscv_v_intrinsic)
+        #if defined(__riscv_zvfh)
+            int vl = __riscv_vsetvlmax_e32m2();
+            vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
+            vfloat32m2_t vsum;
+            vfloat16m1_t ax;
+            vfloat16m1_t ay;
+            vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));
+            for (int i = 0; i < n; i += vl) {
+                vl = __riscv_vsetvl_e16m1(n - i);
+                ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);
+                ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);
+                vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);
+            }
+            vl = __riscv_vsetvlmax_e32m1();
+            vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);
+            vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);
+            sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
+        #else
+            for (int i = 0; i < n; ++i) {
+                sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
+            }
+        #endif // __riscv_zvfh
     #else
         const int np = (n & ~(GGML_F16_STEP - 1));
 
@@ -302,7 +331,7 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
     for (int i = 0; i < n; ++i) {
         sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
     }
-#endif
+#endif // GGML_SIMD
 
     *s = sumf;
 }
@@ -361,6 +390,14 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
     for (; i + 3 < n; i += 4) {
         vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
     }
+#elif defined(__riscv_v_intrinsic)
+    for (int vl; i < n; i += vl) {
+        vl = __riscv_vsetvl_e32m2(n - i);
+        vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
+        vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);
+        vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);
+        __riscv_vse32_v_f32m2(&y[i], vy, vl);
+    }
 #endif
     for (; i < n; ++i) {
         y[i] = ggml_silu_f32(x[i]) * g[i];
index 1346e7d7e057bac67678ca8e152ff5932e87142a..ef334d089d1f71b51296e59dde7186578e763015 100644 (file)
@@ -1269,6 +1269,14 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
         vl);
 }
 
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
+    const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
+    const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
+    const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
+    return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
+}
+
 #endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
 
 inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {