]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cpu : optimize RVV q2_k and q3_k kernels (#16887)
authorxctan <redacted>
Thu, 6 Nov 2025 16:12:45 +0000 (00:12 +0800)
committerGitHub <redacted>
Thu, 6 Nov 2025 16:12:45 +0000 (18:12 +0200)
ggml/src/ggml-cpu/arch/riscv/quants.c

index ee41a3502e82d1f4d33a8cdf9473d29e1034fad4..ae0ebb3cad11b9a236d6956e3e7ee6c91608ebfa 100644 (file)
@@ -580,16 +580,19 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
             const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
             uint8_t *patmp = atmp;
             int vsums;
-            int tmp;
+            int tmp, t1, t2, t3, t4, t5, t6, t7;
             __asm__ __volatile__(
                 "vsetivli zero, 16, e8, m1\n\t"
                 "vmv.v.x v8, zero\n\t"
+                "lb zero, 15(%[sc])\n\t"
                 "vle8.v v1, (%[sc])\n\t"
+                "vle8.v v2, (%[bsums])\n\t"
+                "addi %[tmp], %[bsums], 16\n\t"
                 "vand.vi v0, v1, 0xF\n\t"
                 "vsrl.vi v1, v1, 4\n\t"
+                "vle8.v v3, (%[tmp])\n\t"
                 "vse8.v v0, (%[scale])\n\t"
                 "vsetivli zero, 16, e16, m2\n\t"
-                "vle16.v v2, (%[bsums])\n\t"
                 "vzext.vf2 v0, v1\n\t"
                 "vwmul.vv v4, v0, v2\n\t"
                 "vsetivli zero, 16, e32, m4\n\t"
@@ -608,46 +611,89 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
             for (int j = 0; j < QK_K/128; ++j) {
                 __asm__ __volatile__(
-                    "vsetvli zero, %[vl32], e8, m2\n\t"
+                    "lb zero, 31(%[q2])\n\t"
+                    "addi %[tmp], %[q2], 16\n\t"
+                    "addi %[t1], %[q8], 16\n\t"
+                    "vsetivli zero, 16, e8, m1\n\t"
                     "vle8.v v0, (%[q2])\n\t"
+                    "vle8.v v1, (%[tmp])\n\t"
                     "vsrl.vi v2, v0, 2\n\t"
+                    "vsrl.vi v3, v1, 2\n\t"
                     "vsrl.vi v4, v0, 4\n\t"
+                    "addi %[tmp], %[q8], 32\n\t"
+                    "vle8.v v8, (%[q8])\n\t"
+                    "vle8.v v9, (%[t1])\n\t"
+                    "addi %[t1], %[t1], 32\n\t"
+                    "vsrl.vi v5, v1, 4\n\t"
                     "vsrl.vi v6, v0, 6\n\t"
+                    "vsrl.vi v7, v1, 6\n\t"
+                    "vle8.v v10, (%[tmp])\n\t"
+                    "vle8.v v11, (%[t1])\n\t"
+                    "addi %[tmp], %[tmp], 32\n\t"
+                    "addi %[t1], %[t1], 32\n\t"
                     "vand.vi v0, v0, 0x3\n\t"
+                    "vand.vi v1, v1, 0x3\n\t"
                     "vand.vi v2, v2, 0x3\n\t"
+                    "vle8.v v12, (%[tmp])\n\t"
+                    "vle8.v v13, (%[t1])\n\t"
+                    "addi %[tmp], %[tmp], 32\n\t"
+                    "addi %[t1], %[t1], 32\n\t"
+                    "vand.vi v3, v3, 0x3\n\t"
                     "vand.vi v4, v4, 0x3\n\t"
-                    "vsetvli zero, %[vl128], e8, m8\n\t"
-                    "vle8.v v8, (%[q8])\n\t"
-                    "vsetvli zero, %[vl64], e8, m4\n\t"
+                    "vand.vi v5, v5, 0x3\n\t"
+                    "vle8.v v14, (%[tmp])\n\t"
+                    "vle8.v v15, (%[t1])\n\t"
                     "vwmul.vv v16, v0, v8\n\t"
+                    "vwmul.vv v18, v1, v9\n\t"
+                    "vwmul.vv v20, v2, v10\n\t"
+                    "vwmul.vv v22, v3, v11\n\t"
                     "vwmul.vv v24, v4, v12\n\t"
-                    "vsetivli zero, 16, e16, m2\n\t"
+                    "vwmul.vv v26, v5, v13\n\t"
+                    "vwmul.vv v28, v6, v14\n\t"
+                    "vwmul.vv v30, v7, v15\n\t"
+                    "vsetivli zero, 8, e16, m1\n\t"
                     "vmv.v.x v0, zero\n\t"
-                    "vwredsum.vs v10, v16, v0\n\t"
+                    "lbu %[tmp], 0(%[scale])\n\t"
+                    "vwredsum.vs v8, v16, v0\n\t"
                     "vwredsum.vs v9, v18, v0\n\t"
-                    "vwredsum.vs v8, v20, v0\n\t"
-                    "vwredsum.vs v7, v22, v0\n\t"
-                    "vwredsum.vs v11, v24, v0\n\t"
-                    "vwredsum.vs v12, v26, v0\n\t"
-                    "vwredsum.vs v13, v28, v0\n\t"
-                    "vwredsum.vs v14, v30, v0\n\t"
+                    "lbu %[t1], 1(%[scale])\n\t"
+                    "vwredsum.vs v10, v20, v0\n\t"
+                    "vwredsum.vs v11, v22, v0\n\t"
+                    "lbu %[t2], 2(%[scale])\n\t"
+                    "vwredsum.vs v12, v24, v0\n\t"
+                    "vwredsum.vs v13, v26, v0\n\t"
+                    "lbu %[t3], 3(%[scale])\n\t"
+                    "vwredsum.vs v14, v28, v0\n\t"
+                    "vwredsum.vs v15, v30, v0\n\t"
+                    "lbu %[t4], 4(%[scale])\n\t"
+                    "vwredsum.vs v8, v17, v8\n\t"
+                    "vwredsum.vs v9, v19, v9\n\t"
+                    "lbu %[t5], 5(%[scale])\n\t"
+                    "vwredsum.vs v10, v21, v10\n\t"
+                    "vwredsum.vs v11, v23, v11\n\t"
+                    "lbu %[t6], 6(%[scale])\n\t"
+                    "vwredsum.vs v12, v25, v12\n\t"
+                    "vwredsum.vs v13, v27, v13\n\t"
+                    "lbu %[t7], 7(%[scale])\n\t"
+                    "vwredsum.vs v14, v29, v14\n\t"
+                    "vwredsum.vs v15, v31, v15\n\t"
                     "vsetivli zero, 4, e32, m1\n\t"
-                    "vslideup.vi v10, v9, 1\n\t"
-                    "vslideup.vi v8, v7, 1\n\t"
-                    "vslideup.vi v11, v12, 1\n\t"
-                    "vslideup.vi v13, v14, 1\n\t"
-                    "vslideup.vi v10, v8, 2\n\t"
-                    "vslideup.vi v11, v13, 2\n\t"
-                    "vsetivli zero, 8, e32, m2\n\t"
-                    "vle8.v v15, (%[scale])\n\t"
-                    "vzext.vf4 v12, v15\n\t"
-                    "vmul.vv v10, v10, v12\n\t"
-                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmul.vx v0, v8, %[tmp]\n\t"
+                    "vmul.vx v1, v9, %[t1]\n\t"
+                    "vmacc.vx v0, %[t2], v10\n\t"
+                    "vmacc.vx v1, %[t3], v11\n\t"
+                    "vmacc.vx v0, %[t4], v12\n\t"
+                    "vmacc.vx v1, %[t5], v13\n\t"
+                    "vmacc.vx v0, %[t6], v14\n\t"
+                    "vmacc.vx v1, %[t7], v15\n\t"
                     "vmv.x.s %[tmp], v0\n\t"
-                    "add %[isum], %[isum], %[tmp]"
-                    : [tmp] "=&r" (tmp), [isum] "+&r" (isum)
+                    "vmv.x.s %[t1], v1\n\t"
+                    "add %[isum], %[isum], %[tmp]\n\t"
+                    "add %[isum], %[isum], %[t1]"
+                    : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
+                    , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
+                    , [isum] "+&r" (isum)
                     : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
-                    , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
                     : "memory"
                     , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
                     , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
@@ -929,7 +975,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
             const  int8_t * restrict q8 = y[i].qs;
 
             int8_t * scale = (int8_t *)utmp;
-            int tmp;
+            int tmp, t1, t2, t3, t4, t5, t6, t7;
             __asm__ __volatile__(
                 "vsetivli zero, 12, e8, m1\n\t"
                 "vle8.v v0, (%[s6b])\n\t"
@@ -967,19 +1013,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
             int isum = 0;
             for (int j = 0; j < QK_K; j += 128) {
                 __asm__ __volatile__(
+                    "lb zero, 31(%[q3])\n\t"
                     "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t"
                     "vle8.v v8, (%[q3])\n\t"
                     "vsrl.vi v10, v8, 2\n\t"
                     "vsrl.vi v12, v8, 4\n\t"
                     "vsrl.vi v14, v8, 6\n\t"
+                    "lb zero, 64(%[q8])\n\t"
                     "vand.vi v8, v8, 3\n\t"
                     "vand.vi v10, v10, 3\n\t"
                     "vand.vi v12, v12, 3\n\t"
                     "vle8.v v2, (%[qh])\n\t"
+                    "lb zero, 127(%[q8])\n\t"
                     "vand.vx v4, v2, %[m]\n\t"
                     "slli %[m], %[m], 1\n\t"
                     "vmseq.vx v0, v4, zero\n\t"
                     "vadd.vi v8, v8, -4, v0.t\n\t"
+                    "lb zero, 0(%[q8])\n\t"
                     "vand.vx v4, v2, %[m]\n\t"
                     "slli %[m], %[m], 1\n\t"
                     "vmseq.vx v0, v4, zero\n\t"
@@ -994,34 +1044,43 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                     "vadd.vi v14, v14, -4, v0.t\n\t"
                     "vsetvli zero, %[vl128], e8, m8\n\t"
                     "vle8.v v0, (%[q8])\n\t"
+                    "lb %[tmp], 0(%[scale])\n\t"
+                    "lb %[t1], 1(%[scale])\n\t"
+                    "lb %[t2], 2(%[scale])\n\t"
+                    "lb %[t3], 3(%[scale])\n\t"
                     "vsetvli zero, %[vl64], e8, m4\n\t"
                     "vwmul.vv v16, v0, v8\n\t"
                     "vwmul.vv v24, v4, v12\n\t"
                     "vsetivli zero, 16, e16, m2\n\t"
                     "vmv.v.x v0, zero\n\t"
-                    "vwredsum.vs v10, v16, v0\n\t"
+                    "vwredsum.vs v8, v16, v0\n\t"
+                    "lb %[t4], 4(%[scale])\n\t"
+                    "lb %[t5], 5(%[scale])\n\t"
                     "vwredsum.vs v9, v18, v0\n\t"
-                    "vwredsum.vs v8, v20, v0\n\t"
-                    "vwredsum.vs v7, v22, v0\n\t"
-                    "vwredsum.vs v11, v24, v0\n\t"
-                    "vwredsum.vs v12, v26, v0\n\t"
-                    "vwredsum.vs v13, v28, v0\n\t"
-                    "vwredsum.vs v14, v30, v0\n\t"
+                    "vwredsum.vs v10, v20, v0\n\t"
+                    "vwredsum.vs v11, v22, v0\n\t"
+                    "vwredsum.vs v12, v24, v0\n\t"
+                    "lb %[t6], 6(%[scale])\n\t"
+                    "lb %[t7], 7(%[scale])\n\t"
+                    "vwredsum.vs v13, v26, v0\n\t"
+                    "vwredsum.vs v14, v28, v0\n\t"
+                    "vwredsum.vs v15, v30, v0\n\t"
                     "vsetivli zero, 4, e32, m1\n\t"
-                    "vslideup.vi v10, v9, 1\n\t"
-                    "vslideup.vi v8, v7, 1\n\t"
-                    "vslideup.vi v11, v12, 1\n\t"
-                    "vslideup.vi v13, v14, 1\n\t"
-                    "vslideup.vi v10, v8, 2\n\t"
-                    "vslideup.vi v11, v13, 2\n\t"
-                    "vsetivli zero, 8, e32, m2\n\t"
-                    "vle8.v v15, (%[scale])\n\t"
-                    "vsext.vf4 v12, v15\n\t"
-                    "vmul.vv v10, v10, v12\n\t"
-                    "vredsum.vs v0, v10, v0\n\t"
+                    "vmul.vx v0, v8, %[tmp]\n\t"
+                    "vmul.vx v1, v9, %[t1]\n\t"
+                    "vmacc.vx v0, %[t2], v10\n\t"
+                    "vmacc.vx v1, %[t3], v11\n\t"
+                    "vmacc.vx v0, %[t4], v12\n\t"
+                    "vmacc.vx v1, %[t5], v13\n\t"
+                    "vmacc.vx v0, %[t6], v14\n\t"
+                    "vmacc.vx v1, %[t7], v15\n\t"
                     "vmv.x.s %[tmp], v0\n\t"
-                    "add %[isum], %[isum], %[tmp]"
-                    : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum)
+                    "vmv.x.s %[t1], v1\n\t"
+                    "add %[isum], %[isum], %[tmp]\n\t"
+                    "add %[isum], %[isum], %[t1]"
+                    : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
+                    , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
+                    , [m] "+&r" (m), [isum] "+&r" (isum)
                     : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32)
                     , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8)
                     : "memory"