]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Quantization imrovements for k_quants (#2707)
authorKawrakow <redacted>
Tue, 22 Aug 2023 16:14:09 +0000 (19:14 +0300)
committerGitHub <redacted>
Tue, 22 Aug 2023 16:14:09 +0000 (19:14 +0300)
* Improve LLaMA-2 2-, 3- and 4-bit quantization

* Q3_K_S: use Q5_K for 1st 2 layers of attention.wv and feed_forward.w2
* Q4_K_S: use Q6_K for 1st 2 layers of attention.wv and feed_forward.w2
* Q2_K and Q3_K_M: use Q5_K instead of Q4_K for 1st 2 layers of
  attention.wv and feed_forward.w2

This leads to a slight model sized increase as follows:
Q2_K  : 2.684G vs 2.670G
Q3_K_S: 2.775G vs 2.745G
Q3_K_M: 3.071G vs 3.057G
Q4_K_S: 3.592G vs 3.563G

LLaMA-2 PPL for context 512 changes as follows:
Q2_K  : 6.6691 vs 6.8201
Q3_K_S: 6.2129 vs 6.2584
Q3_K_M: 6.0387 vs 6.1371
Q4_K_S: 5.9138 vs 6.0041

There are improvements for LLaMA-1 as well, but they are
way smaller than the above.

* Minor 4-bit quantization improvement

For the same model size as previus commit, we get
PPL = 5.9069 vs 5.9138.

* Some more fine tuning

* Adding make_qkx2_quants

With it, we get PPL = 5.8828 for L2-7B Q4_K_S.

* Another minor improvement

* Q2_K improvement

Smaller model, lower perplexity.
 7B: file size = 2.632G, PPL = 6.3772 vs original 2.670G PPL = 6.8201
12B: file size = 5.056G, PPL = 5.4577 vs original 5.130G PPL = 5.7178

It is mostly Q3_K except for tok_embeddings, attention.wq, attention.wk,
which are Q2_K

* Iterating

* Revert Q5_K back to make_qkx1_quants

* Better Q6_K

* make_qkx2_quants is better for Q5_K after all

* Fix after rebasing on master

* Fix for changed tensor names

---------

Co-authored-by: Iwan Kawrakow <redacted>
k_quants.c
llama.cpp

index 6348fce6b94d031071302b40c10c5113850f473c..82bf816976c00c8a365bd4252d1223102366c94d 100644 (file)
@@ -77,6 +77,11 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
         }
         return 1/iscale;
     }
+    bool return_early = false;
+    if (rmse_type < 0) {
+        rmse_type = -rmse_type;
+        return_early = true;
+    }
     int weight_type = rmse_type%2;
     float sumlx = 0;
     float suml2 = 0;
@@ -89,56 +94,9 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t *
         suml2 += w*l*l;
     }
     float scale = sumlx/suml2;
+    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
     float best = scale * sumlx;
-    for (int itry = 0; itry < 3; ++itry) {
-        iscale = 1/scale;
-        float slx = 0;
-        float sl2 = 0;
-        bool changed = false;
-        for (int i = 0; i < n; ++i) {
-            int l = nearest_int(iscale * x[i]);
-            l = MAX(-nmax, MIN(nmax-1, l));
-            if (l + nmax != L[i]) { changed = true; }
-            float w = weight_type == 1 ? x[i] * x[i] : 1.f;
-            slx += w*x[i]*l;
-            sl2 += w*l*l;
-        }
-        if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
-        for (int i = 0; i < n; ++i) {
-            int l = nearest_int(iscale * x[i]);
-            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
-        }
-        sumlx = slx; suml2 = sl2;
-        scale = sumlx/suml2;
-        best = scale * sumlx;
-    }
-    for (int itry = 0; itry < 5; ++itry) {
-        int n_changed = 0;
-        for (int i = 0; i < n; ++i) {
-            float w = weight_type == 1 ? x[i]*x[i] : 1;
-            int l = L[i] - nmax;
-            float slx = sumlx - w*x[i]*l;
-            if (slx > 0) {
-                float sl2 = suml2 - w*l*l;
-                int new_l = nearest_int(x[i] * sl2 / slx);
-                new_l = MAX(-nmax, MIN(nmax-1, new_l));
-                if (new_l != l) {
-                    slx += w*x[i]*new_l;
-                    sl2 += w*new_l*new_l;
-                    if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
-                        L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
-                        scale = sumlx / suml2; best = scale * sumlx;
-                        ++n_changed;
-                    }
-                }
-            }
-        }
-        if (!n_changed) { break; }
-    }
-    if (rmse_type < 3) {
-        return scale;
-    }
-    for (int is = -4; is <= 4; ++is) {
+    for (int is = -9; is <= 9; ++is) {
         if (is == 0) {
             continue;
         }
@@ -221,12 +179,17 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
     return 1/iscale;
 }
 
-static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        int ntry, float alpha) {
     float min = x[0];
     float max = x[0];
+    float sum_x = 0;
+    float sum_x2 = 0;
     for (int i = 1; i < n; ++i) {
         if (x[i] < min) min = x[i];
         if (x[i] > max) max = x[i];
+        sum_x += x[i];
+        sum_x2 += x[i]*x[i];
     }
     if (max == min) {
         for (int i = 0; i < n; ++i) L[i] = 0;
@@ -254,7 +217,7 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
         for (int i = 0; i < n; ++i) {
             sum += x[i] - scale*L[i];
         }
-        min = sum/n;
+        min = alpha*min + (1 - alpha)*sum/n;
         if (min > 0) min = 0;
         iscale = 1/scale;
         if (!did_change) break;
@@ -263,6 +226,82 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
     return scale;
 }
 
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights[0];
+    float sum_x = sum_w * x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff * diff;
+        float w = weights[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff * diff;
+                float w = weights[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
 #if QK_K == 256
 static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
     if (j < 4) {
@@ -281,6 +320,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
     const int nb = k / QK_K;
 
     uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float   weights[16];
     float mins[QK_K/16];
     float scales[QK_K/16];
 
@@ -291,7 +332,8 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/16; ++j) {
-            scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5);
+            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
+            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
             float scale = scales[j];
             if (scale > max_scale) {
                 max_scale = scale;
@@ -637,6 +679,8 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
     const int nb = k / QK_K;
 
     uint8_t L[QK_K];
+    uint8_t Laux[32];
+    float   weights[32];
     float mins[QK_K/32];
     float scales[QK_K/32];
 
@@ -645,7 +689,12 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
-            scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
+            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
             float scale = scales[j];
             if (scale > max_scale) {
                 max_scale = scale;
@@ -798,6 +847,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
     uint8_t L[QK_K];
     float mins[QK_K/32];
     float scales[QK_K/32];
+    float weights[32];
+    uint8_t Laux[32];
 #else
     int8_t L[QK_K];
     float scales[QK_K/16];
@@ -810,7 +861,12 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict
         float max_scale = 0; // as we are deducting the min, scales are always positive
         float max_min = 0;
         for (int j = 0; j < QK_K/32; ++j) {
-            scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5);
+            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
             float scale = scales[j];
             if (scale > max_scale) {
                 max_scale = scale;
index 8b151dc84c90c5c024e9a1b514910dc3ff34cb5f..0584749c52c9cf240eb8817211d1ed6ae339eb24 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -3547,24 +3547,40 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                     new_type = GGML_TYPE_Q6_K;
                 }
             } else if (name.find("attn_v.weight") != std::string::npos) {
-                if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
+                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
+                    new_type = i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+                }
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                         use_more_bits(i_attention_wv, n_attention_wv)) new_type = GGML_TYPE_Q6_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
                 else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) &&
                         (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K;
                 ++i_attention_wv;
             } else if (name.find("ffn_down.weight") != std::string::npos) {
-                if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
+                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
+                    new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+                }
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
                 else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                          use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
-                //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
                 ++i_feed_forward_w2;
             } else if (name.find("attn_output.weight") != std::string::npos) {
-                if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
+                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K  ) new_type = GGML_TYPE_Q3_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
             }
+            else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) {
+                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
+            }
+            // This can be used to reduce the size of the Q5_K_S model.
+            // The associated PPL increase is fully in line with the size reduction
+            //else {
+            //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
+            //}
             bool convert_incompatible_tensor = false;
             if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
                 new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {