]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
arm64: optimize q6_k_q8_k kernel with i8mm (#13519)
authorYibo Cai <redacted>
Wed, 14 May 2025 19:53:52 +0000 (03:53 +0800)
committerGitHub <redacted>
Wed, 14 May 2025 19:53:52 +0000 (21:53 +0200)
This PR improves q6_k_q8_k gemm kernel with arm64 i8mm instruction.

Tested on neoverse-n2 with llama3 8b q6_k quantization model.
- 40% ~ 54% S_PP uplift for all batch sizes
- 16% ~ 47% S_TG uplift for batch size 4 and above

Perplexity doesn't change with this PR.

```
// tested on neoverse-n2
$ llama-batched-bench \
      -m Meta-Llama-3-8B-Instruct-Q6_K.gguf \
      --no-mmap -fa \
      -c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \
      -npl 1,2,4,8,16,32 \
      -t 64

---------------------------------------------------------------------
|    PP |     TG |    B |       S_PP t/s      |       S_TG t/s      |
|       |        |      | original |  this pr | original |  this pr |
|-------|--------|------|----------|----------|----------|----------|
|   128 |    128 |    1 |    78.52 |   109.18 |    18.63 |    18.88 |
|   128 |    128 |    2 |    84.62 |   123.94 |    34.54 |    36.92 |
|   128 |    128 |    4 |    84.36 |   122.49 |    52.65 |    61.32 |
|   128 |    128 |    8 |    90.52 |   138.87 |    63.46 |    84.41 |
|   128 |    128 |   16 |    90.11 |   138.56 |    71.04 |   101.33 |
|   128 |    128 |   32 |    89.81 |   137.79 |    75.14 |   110.47 |
---------------------------------------------------------------------
```

ggml/src/ggml-cpu/ggml-cpu-quants.c
ggml/src/ggml-cpu/ggml-cpu.c

index ccd0651ebc71402c2af2641eaece27d6314a2d4b..a89ce9bb1e93c5b5bae64593f3093b1c6b8f5abd 100644 (file)
@@ -8519,7 +8519,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+#ifdef __ARM_FEATURE_MATMUL_INT8
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
+#endif
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
@@ -8530,6 +8534,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q6_K * GGML_RESTRICT x0 = x;
+        const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
+        const block_q8_K * GGML_RESTRICT y0 = y;
+        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+        float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+            const uint8_t * GGML_RESTRICT ql0 = x0->ql;
+            const uint8_t * GGML_RESTRICT ql1 = x1->ql;
+            const uint8_t * GGML_RESTRICT qh0 = x0->qh;
+            const uint8_t * GGML_RESTRICT qh1 = x1->qh;
+            const  int8_t * GGML_RESTRICT qy0 = y0->qs;
+            const  int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+            const uint8x16_t mone = vdupq_n_u8(0x30);
+            const uint8x16_t  m4b = vdupq_n_u8(0x0f);
+
+            int32x4_t visum = vdupq_n_s32(0);
+
+            // process 8 blocks per iteration, totally 16 blocks
+            for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) {
+                int8x16_t vx0[8], vx1[8];
+
+                // de-quantize vx0[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // de-quantize vx1[8]
+                {
+                    const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1);
+                    const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1);
+
+                    uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4));
+                    uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4));
+                    uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2));
+                    uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0));
+                    vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1));
+                    vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2));
+                    vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3));
+
+                    q6h_0 = vandq_u8(mone, qh_bits.val[0]);
+                    q6h_1 = vandq_u8(mone, qh_bits.val[1]);
+                    q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2));
+                    q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2));
+
+                    vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0));
+                    vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1));
+                    vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2));
+                    vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3));
+                }
+
+                // process 16 elements (one block with same scale) per iteration
+                // - vx = concat(ql, qh) - 32
+                // - r1,r2,r3,r4 = smmla(vx, vy)
+                for (int k = 0; k < 8; ++k) {
+                    const int blk = j * 8 + k;
+
+                    const int8x16_t vy0 = vld1q_s8(qy0);
+                    const int8x16_t vy1 = vld1q_s8(qy1);
+                    qy0 += 16;
+                    qy1 += 16;
+
+                    const int32x4_t block_scale = {
+                        x0->scales[blk],
+                        x0->scales[blk],
+                        x1->scales[blk],
+                        x1->scales[blk],
+                    };
+
+                    // calculate four results at once with outer product
+                    const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k])));
+                    const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1)));
+                    int32x4_t vr = vdupq_n_s32(0);
+                    vr = vmmlaq_s32(vr, vx_l, vy_l);
+                    vr = vmmlaq_s32(vr, vx_h, vy_h);
+
+                    // apply block scale, will NOT overflow
+                    // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits
+                    visum = vmlaq_s32(visum, vr, block_scale);
+                }
+            }
+
+            // adjust bias, apply superblock scale
+            {
+                int32_t bias[4];
+#ifdef __ARM_FEATURE_SVE
+                const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
+                const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
+                const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
+                const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
+                const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
+                const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
+                const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
+                const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
+                const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
+                const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
+                const svint64_t zero = svdup_n_s64(0);
+                bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
+                bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
+                bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
+                bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
+                                                                               svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
+#else
+                // NEON doesn't support int16 dot product, fallback to separated mul and add
+                const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
+                const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
+
+                int8x16_t scales_s8 = vld1q_s8(x0->scales);
+                const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+                scales_s8 = vld1q_s8(x1->scales);
+                const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}};
+
+                int32x4_t prod;
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[0] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1]))));
+                bias[1] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[2] = vaddvq_s32(prod);
+                prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))),
+                                 vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])),
+                                           vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
+                bias[3] = vaddvq_s32(prod);
+
+#endif
+                const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
+
+                const float32x4_t superblock_scale = {
+                    GGML_FP16_TO_FP32(x0->d) * y0->d,
+                    GGML_FP16_TO_FP32(x0->d) * y1->d,
+                    GGML_FP16_TO_FP32(x1->d) * y0->d,
+                    GGML_FP16_TO_FP32(x1->d) * y1->d,
+                };
+
+                visum = vsubq_s32(visum, vibias);
+                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+            }
+        }
+
+        // vfsum = ABCD -> ACBD
+        // AC -> s, BD -> (s+bs)
+        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+        vst1_f32(s,      vget_low_f32 (vfsum));
+        vst1_f32(s + bs, vget_high_f32(vfsum));
+
+        return;
+    }
+#endif
+
 #ifdef __ARM_FEATURE_SVE
     const int vector_length = ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
index a30e67f22790015b4ac259150fa3962d0e798a4b..133b50606bcd1dd9779b7524ce78924dbee9548b 100644 (file)
@@ -282,7 +282,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .from_float               = quantize_row_q6_K,
         .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
         .nrows                    = 1,
+#endif
     },
     [GGML_TYPE_IQ2_XXS] = {
         .from_float               = NULL,