]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Q4_2 quantization with rmse-optimized scale and quants (#1062)
authorKawrakow <redacted>
Wed, 19 Apr 2023 18:20:14 +0000 (20:20 +0200)
committerGitHub <redacted>
Wed, 19 Apr 2023 18:20:14 +0000 (20:20 +0200)
* Q4_2 quantization with rmse-optimized scale and quants

For quantize-stats we get
q4_2: rmse 0.00159301, maxerr 0.17480469, 95pct<0.0030, median<0.0012

For 7B perplexity with BLAS enabled we get 6.2038 after 655 chunks.

Quantization is slow (~90 seconds on my Mac for 7B) as not
multi-threaded as in PR #896.

* ggml : satisfy the sanitizer builds

Not sure why this makes them fail

* Better follow ggml conventions for function names

* Fixed type as per reviewer comment

---------

Co-authored-by: Iwan Kawrakow <redacted>
Co-authored-by: Georgi Gerganov <redacted>
ggml.c

diff --git a/ggml.c b/ggml.c
index 3b38eaad3673612fe62cbcd11c4b9588d944062c..431cdb9c907e4e3d84d87f30488b6647dc0f5fe4 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -1135,12 +1136,94 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
     }
 }
 
+static inline int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
+        const float * restrict candidates, int8_t * restrict L) {
+    assert (nmin >= INT8_MIN);
+    assert (nmax <= INT8_MAX);
+    float amax = 0;
+    for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
+    if (!amax) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<nCandidates; ++si) {
+        float iscale = candidates[si]/amax;
+        float sumlxP = 0; int suml2P = 0;
+        float sumlxM = 0; int suml2M = 0;
+        for (int i=0; i<n; ++i) {
+            int l = nearest_int(iscale*X[i]);
+            int lp = MAX(nmin, MIN(nmax, +l));
+            int lm = MAX(nmin, MIN(nmax, -l));
+            sumlxP += X[i]*lp; suml2P += lp*lp;
+            sumlxM += X[i]*lm; suml2M += lm*lm;
+        }
+        float sumlxP2 = sumlxP*sumlxP;
+        float sumlxM2 = sumlxM*sumlxM;
+        if (sumlxP2*suml2M > sumlxM2*suml2P) {
+            if (sumlxP2 > best*suml2P) {
+                best = sumlxP2/suml2P; bestScale = iscale;
+            }
+        } else {
+            if (sumlxM2 > best*suml2M) {
+                best = sumlxM2/suml2M; bestScale = -iscale;
+            }
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = nearest_int(bestScale*X[i]);
+        l = MAX(nmin, MIN(nmax, l));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    float scale = sumlx/suml2;
+    return scale;
+}
+
+static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
+#define CANDIDATE_COUNT 8
+    static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
+    assert(k % QK4_2 == 0);
+
+    int8_t L[QK4_2];
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+
+        float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
+        y[i].d = GGML_FP32_TO_FP16(scale);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
+            const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+
+        x += QK4_2;
+    }
+}
+
 static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK4_2 == 0);
 
     block_q4_2 * restrict y = vy;
 
-    quantize_row_q4_2_reference(x, y, k);
+    //quantize_row_q4_2_reference(x, y, k);
+    // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
+    quantize_row_q4_2_rmse(x, y, k);
 }
 
 // reference implementation for deterministic creation of model files
@@ -1569,7 +1652,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q         = dequantize_row_q4_2,
         .quantize_row_q           = quantize_row_q4_2,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
         .quantize_row_q_dot       = quantize_row_q8_0,
         .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
     },
@@ -11770,7 +11853,8 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
     for (int j = 0; j < n; j += k) {
         block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
 
-        quantize_row_q4_2_reference(src + j, y, k);
+        //quantize_row_q4_2_reference(src + j, y, k);
+        quantize_row_q4_2_rmse(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK4_2; l += 2) {