]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : alternative Q4_3 implementation using modified Q8_0 (#1109)
authorGeorgi Gerganov <redacted>
Sat, 22 Apr 2023 07:55:35 +0000 (10:55 +0300)
committerGitHub <redacted>
Sat, 22 Apr 2023 07:55:35 +0000 (10:55 +0300)
* ggml : prefer vzip to vuzp

This way we always use the same type of instruction across all quantizations

* ggml : alternative Q4_3 implementation using modified Q8_0

* ggml : fix Q4_3 scalar imlpementation

* ggml : slight improvement of Q4_3 - no need for loop unrolling

* ggml : fix AVX paths for Q8_0 quantization

ggml.c

diff --git a/ggml.c b/ggml.c
index 814776381b574c828722b2eeb915c98e8752c06f..72b392fdb87e1a7c9b07856420842fdfb1d79532 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
-    float   s;          // d * sum(qs[i])
+    float   s0;         // d * sum(qs[i]) low
+    float   s1;         // d * sum(qs[i]) high
     int8_t  qs[QK8_0];  // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
 
 
 // reference implementation for deterministic creation of model files
@@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
 
         y[i].d = d;
 
-        int sum = 0;
-        for (int l = 0; l < QK8_0; ++l) {
-            const float v = x[i*QK8_0 + l]*id;
-            y[i].qs[l] = roundf(v);
-            sum += y[i].qs[l];
+        int sum0 = 0;
+        int sum1 = 0;
+
+        for (int l = 0; l < QK8_0/2; ++l) {
+            const float v0 = x[i*QK8_0           + l]*id;
+            const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id;
+
+            y[i].qs[          l] = roundf(v0);
+            y[i].qs[QK8_0/2 + l] = roundf(v1);
+
+            sum0 += y[i].qs[          l];
+            sum1 += y[i].qs[QK8_0/2 + l];
         }
-        y[i].s = d * sum;
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 }
 
@@ -1335,9 +1345,24 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         y[i].d = d;
 
-        int32x4_t accv = vdupq_n_s32(0);
+        int32x4_t accv0 = vdupq_n_s32(0);
+        int32x4_t accv1 = vdupq_n_s32(0);
 
-        for (int l = 0; l < 8; l++) {
+        // low half
+        for (int l = 0; l < 4; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+
+            accv0 = vaddq_s32(accv0, vi);
+        }
+
+        // high half
+        for (int l = 4; l < 8; l++) {
             const float32x4_t v  = vmulq_n_f32(srcv[l], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
 
@@ -1346,12 +1371,17 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
             y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
             y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
 
-            accv = vaddq_s32(accv, vi);
+            accv1 = vaddq_s32(accv1, vi);
         }
-        int32_t sum = vaddvq_s32(accv);
-        y[i].s = d * sum;
+
+        const int32_t sum0 = vaddvq_s32(accv0);
+        const int32_t sum1 = vaddvq_s32(accv1);
+
+        y[i].s0 = d * sum0;
+        y[i].s1 = d * sum1;
     }
 #elif defined(__AVX2__) || defined(__AVX__)
+    // TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
     for (int i = 0; i < nb; i++) {
         // Load elements into 4 AVX vectors
         __m256 v0 = _mm256_loadu_ps( x );
@@ -1398,7 +1428,9 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
 #if defined(__AVX2__)
         // Compute the sum of the quants and set y[i].s
-        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1));
+        y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3));
 
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
@@ -2395,7 +2427,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
-        sum8 += x0->d * y0->s + x1->d * y1->s;
+        sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1);
 
         const uint8x16_t m4b   = vdupq_n_u8(0xf);
 
@@ -2562,7 +2594,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
-        summs += x0->m * y0->s + x1->m * y1->s;
+        summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
 
         const uint8x16_t m4b = vdupq_n_u8(0xf);
 
@@ -2575,22 +2607,22 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        // interleave
-        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
-        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
-        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
@@ -2627,7 +2659,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const float * d0 = &x[i].d;
         const float * d1 = &y[i].d;
 
-        summs += x[i].m * y[i].s;
+        summs += x[i].m * (y[i].s0 + y[i].s1);
 
         const __m256 d0v = _mm256_broadcast_ss( d0 );
         const __m256 d1v = _mm256_broadcast_ss( d1 );
@@ -2845,88 +2877,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
-    for (int i = 0; i < nb; i += 2) {
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
         const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
         const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
-        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
 
         const block_q8_0 * restrict y0 = &y[i + 0];
-        const block_q8_0 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
-        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
 
-        const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
-        const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
-        const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
-        const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
+        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
+        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
 
         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0xf)));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
 
         // interleave
         const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
         const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
-        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
 
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-        const int8x16_t v1_1l = vld1q_s8(y1->qs);
-        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
-
-        const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
-        const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
 
-        const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
-        const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
-
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
         const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
         const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
 
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
-
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
-        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
 #endif
     }
 
-    *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2971,9 +2968,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
         const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
         const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
 
-        int sy_0 = 0;
-        int sy_1 = 0;
-
         int sxy_0 = 0;
         int sxy_1 = 0;
 
@@ -2993,15 +2987,11 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
             const int y0_1 = y0[2*(j + QK8_0/4) + 0];
             const int y1_1 = y0[2*(j + QK8_0/4) + 1];
 
-            sy_0 += y0_0 + y1_0;
-            sy_1 += y0_1 + y1_1;
-
             sxy_0 += x0_0*y0_0 + x1_0*y1_0;
             sxy_1 += x0_1*y0_1 + x1_1*y1_1;
         }
 
-        sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
-        sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+        sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
     }
     *s = sumf;
 #endif