arm64: optimize q4_k_q8_k kernel with i8mm (#13886)
author    Yibo Cai <redacted>
Thu, 29 May 2025 11:39:20 +0000 (19:39 +0800)
committer GitHub <redacted>
Thu, 29 May 2025 11:39:20 +0000 (14:39 +0300)
This PR improves the q4_k_q8_k GEMM kernel with the arm64 i8mm instruction.
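
For context, the i8mm extension adds the smmla instruction (exposed as the vmmlaq_s32 intrinsic), which reads each 128-bit operand as a 2x8 int8 matrix and accumulates the 2x2 int32 product A * B^T, i.e. four 8-element dot products per instruction. A minimal standalone sketch of its semantics (file name and build flags are illustrative, not from this PR):

```
// Hedged demo of a single smmla, separate from the patch below.
// Build with something like: gcc -O2 -march=armv8.2-a+i8mm smmla_demo.c
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    // Each 128-bit operand is read as a 2x8 row-major int8 matrix.
    const int8_t a_bytes[16] = {1, 2, 3, 4, 5, 6, 7, 8,   // row 0 of A
                                1, 1, 1, 1, 1, 1, 1, 1};  // row 1 of A
    const int8_t b_bytes[16] = {1, 0, 0, 0, 0, 0, 0, 0,   // row 0 of B
                                1, 1, 1, 1, 1, 1, 1, 1};  // row 1 of B
    const int8x16_t a = vld1q_s8(a_bytes);
    const int8x16_t b = vld1q_s8(b_bytes);

    // c[i][j] += dot(row i of A, row j of B); lanes are [c00, c01, c10, c11].
    int32x4_t c = vdupq_n_s32(0);
    c = vmmlaq_s32(c, a, b);

    int32_t out[4];
    vst1q_s32(out, c);
    printf("%d %d\n%d %d\n", out[0], out[1], out[2], out[3]); // 1 36 / 1 8
    return 0;
}
```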

Tested on neoverse-n2 with a llama3 8b q4_k_m quantized model.
- 34% ~ 50% S_PP uplift for all batch sizes
- 12% ~ 37% S_TG uplift for batch size 4 and above

Perplexity doesn't change with this PR.

```
// tested on neoverse-n2
$ llama-batched-bench \
      -m Meta-Llama-3-8B-Instruct-Q4_K_M.gguf \
      --no-mmap -fa \
      -c 8192 -b 4096 -ub 512 -npp 128 -ntg 128 \
      -npl 1,2,4,8,16,32 \
      -t 64

---------------------------------------------------------------------
|    PP |     TG |    B |       S_PP t/s      |       S_TG t/s      |
|       |        |      | original |  this pr | original |  this pr |
|-------|--------|------|----------|----------|----------|----------|
|   128 |    128 |    1 |   110.12 |   147.83 |    24.36 |    24.28 |
|   128 |    128 |    2 |   121.16 |   172.42 |    46.36 |    47.93 |
|   128 |    128 |    4 |   120.15 |   169.75 |    74.68 |    84.00 |
|   128 |    128 |    8 |   130.97 |   196.81 |    91.04 |   114.74 |
|   128 |    128 |   16 |   131.01 |   196.88 |   101.43 |   135.79 |
|   128 |    128 |   32 |   130.85 |   196.51 |   106.97 |   147.29 |
---------------------------------------------------------------------
```

ggml/src/ggml-cpu/ggml-cpu-quants.c
ggml/src/ggml-cpu/ggml-cpu.c

diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c
index fe4a5a8336935e4416ea8607fc0e29d31f86597d..40bded4767b47aafbf6e7d9c8adf959481a7e01f 100644
@@ -6995,7 +6995,11 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
+#ifdef __ARM_FEATURE_MATMUL_INT8
+    assert((nrc == 2) || (nrc == 1));
+#else
     assert(nrc == 1);
+#endif
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
@@ -7012,6 +7016,146 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     uint32_t utmp[4];
 
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_K * GGML_RESTRICT x0 = x;
+        const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
+        const block_q8_K * GGML_RESTRICT y0 = y;
+        const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by);
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+        float32x4_t vfsum = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) {
+            const uint8_t * GGML_RESTRICT qx0 = x0->qs;
+            const uint8_t * GGML_RESTRICT qx1 = x1->qs;
+            const  int8_t * GGML_RESTRICT qy0 = y0->qs;
+            const  int8_t * GGML_RESTRICT qy1 = y1->qs;
+
+            // decode scales and mins
+            int8_t x0_scales[8], x1_scales[8];
+            int16x8_t x0_mins, x1_mins;
+            {
+                uint32_t scales_mins[3];
+                memcpy(scales_mins, x0->scales, 12);
+                const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+                const uint32x2_t mins = {mins_0_3, mins_4_7};
+                x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+                uint32_t scales[2];
+                scales[0] = scales_mins[0] & kmask1; // scales 0~3
+                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+                memcpy(x0_scales, scales, 8);
+            }
+            {
+                uint32_t scales_mins[3];
+                memcpy(scales_mins, x1->scales, 12);
+                const uint32_t mins_0_3 = scales_mins[1] & kmask1;
+                const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4);
+                const uint32x2_t mins = {mins_0_3, mins_4_7};
+                x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins)));
+                uint32_t scales[2];
+                scales[0] = scales_mins[0] & kmask1; // scales 0~3
+                scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7
+                memcpy(x1_scales, scales, 8);
+            }
+
+            int32x4_t visum = {0};
+
+            // process 64 data points per iteration, 256 data points in total
+            for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) {
+                const int8x16x4_t vy0 = vld1q_s8_x4(qy0);
+                const int8x16x4_t vy1 = vld1q_s8_x4(qy1);
+
+                int8x16_t vx0[4], vx1[4];
+                {
+                    const uint8x16x2_t vv = vld1q_u8_x2(qx0);
+                    vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+                    vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+                    vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+                    vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+                }
+                {
+                    const uint8x16x2_t vv = vld1q_u8_x2(qx1);
+                    vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b));
+                    vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b));
+                    vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4));
+                    vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4));
+                }
+
+                // process 32 data points (sharing the same block scale) per iteration
+                for (int k = 0; k < 2; ++k) {
+                    const int blk = j * 2 + k;
+                    const int32x4_t block_scale = {
+                        x0_scales[blk],
+                        x0_scales[blk],
+                        x1_scales[blk],
+                        x1_scales[blk],
+                    };
+
+                    int32x4_t vr = {0};
+                    for (int l = 0; l < 2; ++l) {
+                        const int idx = k * 2 + l;
+                        const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]);
+                        const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]);
+                        const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]);
+                        const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]);
+                        const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64));
+                        const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64));
+                        const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64));
+                        const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64));
+                        vr = vmmlaq_s32(vr, vx_l, vy_l);
+                        vr = vmmlaq_s32(vr, vx_h, vy_h);
+                    }
+                    // apply block scale, will NOT overflow
+                    // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits
+                    visum = vmlaq_s32(visum, vr, block_scale);
+                }
+            }
+
+            // adjust bias, apply superblock scale
+            {
+                int32_t bias[4];
+                // no obvious uplift from sve sdot-16, just use neon mul add
+                const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8));
+                const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8));
+                bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)),
+                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins))));
+                bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)),
+                                               vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins))));
+                bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)),
+                                               vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins))));
+                bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)),
+                                               vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins))));
+                const float32x4_t dmins = {
+                    GGML_FP16_TO_FP32(x0->dmin) * y0->d,
+                    GGML_FP16_TO_FP32(x0->dmin) * y1->d,
+                    GGML_FP16_TO_FP32(x1->dmin) * y0->d,
+                    GGML_FP16_TO_FP32(x1->dmin) * y1->d,
+                };
+                vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins);
+
+                const float32x4_t superblock_scale = {
+                    GGML_FP16_TO_FP32(x0->d) * y0->d,
+                    GGML_FP16_TO_FP32(x0->d) * y1->d,
+                    GGML_FP16_TO_FP32(x1->d) * y0->d,
+                    GGML_FP16_TO_FP32(x1->d) * y1->d,
+                };
+                vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale);
+            }
+        }
+
+        // vfsum = ABCD -> ACBD
+        // AC -> s, BD -> (s+bs)
+        vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2));
+        vst1_f32(s,      vget_low_f32 (vfsum));
+        vst1_f32(s + bs, vget_high_f32(vfsum));
+
+        return;
+    }
+#endif
+
 #ifdef __ARM_FEATURE_SVE
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
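
The key layout trick in the new nrc == 2 path above is the 64-bit zip before each multiply: smmla consumes 8-byte chunks of both rows in a single register, so the per-row vectors are interleaved at 64-bit granularity and two vmmlaq_s32 calls cover a full pair of 16-byte chunks. A hedged distillation of that inner step (generic int8 inputs; the q4_K nibble decode, scales, and bias handling are omitted):

```
#include <arm_neon.h>

// Accumulate a 2x2 tile: lanes of acc become
// [x0.y0, x0.y1, x1.y0, x1.y1], each a dot product over 16 bytes.
static inline int32x4_t tile_2x2_i8mm(int32x4_t acc,
                                      int8x16_t x0, int8x16_t x1,
                                      int8x16_t y0, int8x16_t y1) {
    const int64x2_t x0_s64 = vreinterpretq_s64_s8(x0);
    const int64x2_t x1_s64 = vreinterpretq_s64_s8(x1);
    const int64x2_t y0_s64 = vreinterpretq_s64_s8(y0);
    const int64x2_t y1_s64 = vreinterpretq_s64_s8(y1);
    // low 8 bytes of both rows, then high 8 bytes of both rows
    const int8x16_t x_l = vreinterpretq_s8_s64(vzip1q_s64(x0_s64, x1_s64));
    const int8x16_t x_h = vreinterpretq_s8_s64(vzip2q_s64(x0_s64, x1_s64));
    const int8x16_t y_l = vreinterpretq_s8_s64(vzip1q_s64(y0_s64, y1_s64));
    const int8x16_t y_h = vreinterpretq_s8_s64(vzip2q_s64(y0_s64, y1_s64));
    acc = vmmlaq_s32(acc, x_l, y_l); // partial dots over bytes 0..7
    acc = vmmlaq_s32(acc, x_h, y_h); // plus bytes 8..15
    return acc;
}
```

This lane order matches the visum accumulator in the patch, which is then scaled per block and per superblock before the final 2x2 store.
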
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index aa51dc21a5de4a5d6b4053616dbdbf25e8571046..1dc425fef1419c8da4fa60597dacd67827b45862 100644
@@ -270,7 +270,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .from_float               = quantize_row_q4_K,
         .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
         .nrows                    = 1,
+#endif
     },
     [GGML_TYPE_Q5_K] = {
         .from_float               = quantize_row_q5_K,
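
The second hunk raises .nrows for GGML_TYPE_Q4_K from 1 to 2 when i8mm is compiled in, so the CPU matmul driver invokes the kernel with nrc == 2 and consumes a 2x2 tile of results per call. A hedged reference for what that tile contains (hypothetical helper, derived from the stores at the end of the i8mm path above):

```
#include <stddef.h>

// Forward declaration matching the kernel touched by this commit.
void ggml_vec_dot_q4_K_q8_K(int n, float * s, size_t bs,
                            const void * vx, size_t bx,
                            const void * vy, size_t by, int nrc);

// Hedged reference: the 2x2 tile one nrc == 2 call produces,
// expressed as four nrc == 1 calls over the same two x and y rows.
static void q4_K_q8_K_tile_ref(int n, float tile[2][2],
                               const void * vx, size_t bx,
                               const void * vy, size_t by) {
    const char * x = (const char *) vx;
    const char * y = (const char *) vy;
    ggml_vec_dot_q4_K_q8_K(n, &tile[0][0], 0, x,      0, y,      0, 1); // x0 . y0
    ggml_vec_dot_q4_K_q8_K(n, &tile[0][1], 0, x + bx, 0, y,      0, 1); // x1 . y0
    ggml_vec_dot_q4_K_q8_K(n, &tile[1][0], 0, x,      0, y + by, 0, 1); // x0 . y1
    ggml_vec_dot_q4_K_q8_K(n, &tile[1][1], 0, x + bx, 0, y + by, 0, 1); // x1 . y1
}

// The i8mm fast path computes the same tile in one call:
//   ggml_vec_dot_q4_K_q8_K(n, s, bs, vx, bx, vy, by, 2);
// with s[0], s[1] matching tile[0][0..1] and s[bs], s[bs+1] matching tile[1][0..1].
```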