]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
2-bit quantizations (#4897)
authorKawrakow <redacted>
Sun, 14 Jan 2024 07:45:56 +0000 (09:45 +0200)
committerGitHub <redacted>
Sun, 14 Jan 2024 07:45:56 +0000 (09:45 +0200)
* imatrix: load

* imatrix: WIP

* imatrix: Add Q2_K quantization

* imatrix: also guard against Q2_K_S quantization without importance matrix

* imatrix: guard even more against low-bit quantization misuse

---------

Co-authored-by: Iwan Kawrakow <redacted>
examples/benchmark/benchmark-matmult.cpp
examples/quantize/quantize.cpp
ggml-quants.c
ggml-quants.h
ggml.c
ggml.h
llama.cpp
llama.h
tests/test-backend-ops.cpp

index 434e1d6bd509eaa155a58890f393a8d224575923..e89f3de2fd397e929ae6c735070f27c71bceaa30 100644 (file)
@@ -194,7 +194,7 @@ int main(int argc, char ** argv)  {
     // Set up a the benchmark matrices
     // printf("Creating new tensor q11 & Running quantize\n");
     struct ggml_tensor * q11 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey);
-    ggml_quantize_chunk(qtype, (const float *) m11->data, q11->data, 0, nelements, hist_cur.data());
+    ggml_quantize_chunk(qtype, (const float *) m11->data, q11->data, 0, nelements/m11->ne[0], m11->ne[0], hist_cur.data(), nullptr);
 
     // Set up a the compute graph
     // printf("Creating new tensor q31\n");
@@ -207,7 +207,7 @@ int main(int argc, char ** argv)  {
     // Set up a second graph computation to make sure we override the CPU cache lines
     // printf("Creating new tensor q12 & Running quantize\n");
     struct ggml_tensor * q12 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey);
-    ggml_quantize_chunk(qtype, (const float *) m12->data, q12->data, 0, nelements, hist_cur.data());
+    ggml_quantize_chunk(qtype, (const float *) m12->data, q12->data, 0, nelements/m12->ne[0], m12->ne[0], hist_cur.data(), nullptr);
 
     // printf("Creating new tensor q32\n");
     struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2);
index f878f6911420ac6ce12c647bc5a336fbf22607a5..f4e2175f18612526e35c9f63e4bb8306eb12afee 100644 (file)
@@ -5,6 +5,10 @@
 #include <cstring>
 #include <vector>
 #include <string>
+#include <unordered_map>
+#include <fstream>
+#include <cmath>
+#include <algorithm>
 
 struct quant_option {
     std::string name;
@@ -17,6 +21,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q4_1",   LLAMA_FTYPE_MOSTLY_Q4_1,   " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", },
     { "Q5_0",   LLAMA_FTYPE_MOSTLY_Q5_0,   " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
     { "Q5_1",   LLAMA_FTYPE_MOSTLY_Q5_1,   " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
+    { "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization",            },
+    { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization",            },
     { "Q2_K",   LLAMA_FTYPE_MOSTLY_Q2_K,   " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "Q3_K",   LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
@@ -72,10 +78,14 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 //
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
+    printf("  --imatrixfile_name: use data in file_name as importance matrix for quant optimizations\n");
+    printf("  --include-weights tensor_name: use importance matrix for this/these tensor(s)\n");
+    printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
+    printf("Note: --include-weights and --exclude-weights cannot be used together\n");
     printf("\nAllowed quantization types:\n");
     for (auto & it : QUANT_OPTIONS) {
         if (it.name != "COPY") {
@@ -83,11 +93,93 @@ static void usage(const char * executable) {
         } else {
             printf("          ");
         }
-        printf("%-6s : %s\n", it.name.c_str(), it.desc.c_str());
+        printf("%-7s : %s\n", it.name.c_str(), it.desc.c_str());
     }
     exit(1);
 }
 
+static void load_imatrix(const std::string& imatrix_file, std::unordered_map<std::string, std::vector<float>>& imatrix_data) {
+    std::ifstream in(imatrix_file.c_str(), std::ios::binary);
+    if (!in) {
+        printf("%s: failed to open %s\n",__func__,imatrix_file.c_str());
+        return;
+    }
+    int n_entries;
+    in.read((char*)&n_entries, sizeof(n_entries));
+    if (in.fail() || n_entries < 1) {
+        printf("%s: no data in file %s\n", __func__, imatrix_file.c_str());
+        return;
+    }
+    for (int i = 0; i < n_entries; ++i) {
+        int len; in.read((char *)&len, sizeof(len));
+        std::vector<char> name_as_vec(len+1);
+        in.read((char *)name_as_vec.data(), len);
+        if (in.fail()) {
+            printf("%s: failed reading name for entry %d from %s\n",__func__,i+1,imatrix_file.c_str());
+            return;
+        }
+        name_as_vec[len] = 0;
+        std::string name{name_as_vec.data()};
+        auto& e = imatrix_data[std::move(name)];
+        int ncall;
+        in.read((char*)&ncall, sizeof(ncall));
+        int nval;
+        in.read((char *)&nval, sizeof(nval));
+        if (in.fail() || nval < 1) {
+            printf("%s: failed reading number of values for entry %d\n",__func__,i);
+            imatrix_data = {};
+            return;
+        }
+        e.resize(nval);
+        in.read((char*)e.data(), nval*sizeof(float));
+        if (in.fail()) {
+            printf("%s: failed reading data for entry %d\n",__func__,i);
+            imatrix_data = {};
+            return;
+        }
+        if (ncall > 0) {
+            for (auto& v : e) v /= ncall;
+        }
+    }
+    printf("%s: loaded %d importance matrix entries from %s\n",__func__,int(imatrix_data.size()),imatrix_file.c_str());
+}
+
+static void prepare_imatrix(const std::string& imatrix_file,
+        const std::vector<std::string>& included_weights,
+        const std::vector<std::string>& excluded_weights,
+        std::unordered_map<std::string, std::vector<float>>& imatrix_data) {
+    if (!imatrix_file.empty()) {
+        load_imatrix(imatrix_file, imatrix_data);
+    }
+    if (imatrix_data.empty()) {
+        return;
+    }
+    if (!excluded_weights.empty()) {
+        for (auto& name : excluded_weights) {
+            for (auto it = imatrix_data.begin(); it != imatrix_data.end(); ) {
+                auto pos = it->first.find(name);
+                if (pos != std::string::npos) it = imatrix_data.erase(it);
+                else ++it;
+            }
+        }
+    }
+    if (!included_weights.empty()) {
+        std::unordered_map<std::string, std::vector<float>> tmp;
+        for (auto& name : included_weights) {
+            for (auto& e : imatrix_data) {
+                auto pos = e.first.find(name);
+                if (pos != std::string::npos) {
+                    tmp.emplace(std::move(e));
+                }
+            }
+        }
+        imatrix_data = std::move(tmp);
+    }
+    if (!imatrix_data.empty()) {
+        printf("%s: have %d importance matrix entries\n", __func__, int(imatrix_data.size()));
+    }
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -96,6 +188,8 @@ int main(int argc, char ** argv) {
     llama_model_quantize_params params = llama_model_quantize_default_params();
 
     int arg_idx = 1;
+    std::string imatrix_file;
+    std::vector<std::string> included_weights, excluded_weights;
 
     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -104,15 +198,43 @@ int main(int argc, char ** argv) {
             params.allow_requantize = true;
         } else if (strcmp(argv[arg_idx], "--pure") == 0) {
             params.pure = true;
+        } else if (strcmp(argv[arg_idx], "--imatrix") == 0) {
+            if (arg_idx < argc-1) {
+                imatrix_file = argv[++arg_idx];
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--include-weights") == 0) {
+            if (arg_idx < argc-1) {
+                included_weights.push_back(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--exclude-weights") == 0) {
+            if (arg_idx < argc-1) {
+                excluded_weights.push_back(argv[++arg_idx]);
+            } else {
+                usage(argv[0]);
+            }
         } else {
             usage(argv[0]);
         }
     }
 
     if (argc - arg_idx < 2) {
+        printf("%s: bad arguments\n", argv[0]);
+        usage(argv[0]);
+    }
+    if (!included_weights.empty() && !excluded_weights.empty()) {
         usage(argv[0]);
     }
 
+    std::unordered_map<std::string, std::vector<float>> imatrix_data;
+    prepare_imatrix(imatrix_file, included_weights, excluded_weights, imatrix_data);
+    if (!imatrix_data.empty()) {
+        params.imatrix = &imatrix_data;
+    }
+
     llama_backend_init(false);
 
     // parse command line arguments
@@ -163,6 +285,13 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) && imatrix_data.empty()) {
+        fprintf(stderr, "\n===============================================================================================\n");
+        fprintf(stderr, "Please do not use IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
+        fprintf(stderr, "===============================================================================================\n\n\n");
+        return 1;
+    }
+
     print_build_info();
 
     fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
index 601d155d73696518854c04c316468090d3dd5b6e..9290d54cfba7a4d792f68663e26fbd2b5f46ba1f 100644 (file)
@@ -5,6 +5,8 @@
 #include <string.h>
 #include <assert.h>
 #include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
 
 #ifdef __ARM_NEON
 
@@ -1639,6 +1641,241 @@ size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n
     return (n/QK_K*sizeof(block_q2_K));
 }
 
+static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights ? weights[0] : x[0]*x[0];
+    float sum_x = sum_w * x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights ? weights[i] : x[i]*x[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) {
+        min = 0;
+    }
+    if (max <= min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff*diff;
+        float w = weights ? weights[i] : x[i]*x[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights ? weights[i] : x[i]*x[i];
+            sum_l  += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff*diff;
+                float w = weights ? weights[i] : x[i]*x[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
+    float max = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    if (!max) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = nmax / max;
+    for (int i = 0; i < n; ++i) {
+        L[i] = nearest_int(iscale * x[i]);
+    }
+    float scale = 1/iscale;
+    float best_mse = 0;
+    for (int i = 0; i < n; ++i) {
+        float diff = x[i] - scale*L[i];
+        float w = quant_weights[i];
+        best_mse += w*diff*diff;
+    }
+    for (int is = -4; is <= 4; ++is) {
+        if (is == 0) continue;
+        float iscale_is = (0.1f*is + nmax)/max;
+        float scale_is = 1/iscale_is;
+        float mse = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale_is*x[i]);
+            l = MIN(nmax, l);
+            float diff = x[i] - scale_is*l;
+            float w = quant_weights[i];
+            mse += w*diff*diff;
+        }
+        if (mse < best_mse) {
+            best_mse = mse;
+            iscale = iscale_is;
+        }
+    }
+    float sumlx = 0;
+    float suml2 = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MIN(nmax, l);
+        L[i] = l;
+        float w = quant_weights[i];
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    for (int itry = 0; itry < 5; ++itry) {
+        int n_changed = 0;
+        for (int i = 0; i < n; ++i) {
+            float w = quant_weights[i];
+            float slx = sumlx - w*x[i]*L[i];
+            float sl2 = suml2 - w*L[i]*L[i];
+            if (slx > 0 && sl2 > 0) {
+                int new_l = nearest_int(x[i] * sl2 / slx);
+                new_l = MIN(nmax, new_l);
+                if (new_l != L[i]) {
+                    slx += w*x[i]*new_l;
+                    sl2 += w*new_l*new_l;
+                    if (slx*slx*suml2 > sumlx*sumlx*sl2) {
+                        L[i] = new_l; sumlx = slx; suml2 = sl2;
+                        ++n_changed;
+                    }
+                }
+            }
+        }
+        if (!n_changed) {
+            break;
+        }
+    }
+    return sumlx / suml2;
+}
+
+static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
+    GGML_ASSERT(quant_weights);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+    const bool requantize = true;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+    float sw[QK_K/16];
+    float weight[QK_K/16];
+    uint8_t Ls[QK_K/16], Lm[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+        memset(sw, 0, QK_K/16*sizeof(float));
+        float sumx2 = 0;
+        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+        float sigma2 = sumx2/QK_K;
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float * restrict qw = quant_weights + QK_K * i + 16*j;
+            for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+            for (int l = 0; l < 16; ++l) sw[j] += weight[l];
+            scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+        }
+
+        float dm  = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+        float mm  = make_qp_quants(QK_K/16, 15, mins,   Lm, sw);
+        y[i].d    = GGML_FP32_TO_FP16(dm);
+        y[i].dmin = GGML_FP32_TO_FP16(mm);
+        dm        = GGML_FP16_TO_FP32(y[i].d);
+        mm        = GGML_FP16_TO_FP32(y[i].dmin);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            y[i].scales[j] = Ls[j] | (Lm[j] << 4);
+        }
+
+        if (requantize) {
+            for (int j = 0; j < QK_K/16; ++j) {
+                const float d = dm * (y[i].scales[j] & 0xF);
+                if (!d) continue;
+                const float m = mm * (y[i].scales[j] >> 4);
+                for (int ii = 0; ii < 16; ++ii) {
+                    int l = nearest_int((x[16*j + ii] + m)/d);
+                    l = MAX(0, MIN(3, l));
+                    L[16*j + ii] = l;
+                }
+            }
+        }
+
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
+    if (!quant_weights) {
+        quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int row = 0; row < nrow; ++row) {
+            quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
 //========================= 3-bit (de)-quantization
 
 void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
@@ -2584,14 +2821,6 @@ static const uint8_t ksigns_iq2xs[128] = {
 
 static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
 
-void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
-    (void)x;
-    (void)y;
-    (void)k;
-    assert(k % QK_K == 0);
-    //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
-}
-
 void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -2618,33 +2847,8 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y
     }
 }
 
-void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_iq2_xxs * restrict y = vy;
-    quantize_row_iq2_xxs_reference(x, y, k);
-}
-
-size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
-
-    for (int j = 0; j < n; j += k) {
-        block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
-        quantize_row_iq2_xxs_reference(src + j, y, k);
-    }
-    return (n/QK_K*sizeof(block_iq2_xxs));
-}
-
 // ====================== 2.3125 bpw (de)-quantization
 
-void quantize_row_iq2_xs_reference(const float * restrict x, block_iq2_xs * restrict y, int k) {
-    (void)x;
-    (void)y;
-    (void)k;
-    assert(k % QK_K == 0);
-    //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
-}
-
 void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -2670,23 +2874,6 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
     }
 }
 
-void quantize_row_iq2_xs(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_iq2_xs * restrict y = vy;
-    quantize_row_iq2_xs_reference(x, y, k);
-}
-
-size_t ggml_quantize_iq2_xs(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
-
-    for (int j = 0; j < n; j += k) {
-        block_iq2_xs * restrict y = (block_iq2_xs *)dst + j/QK_K;
-        quantize_row_iq2_xs_reference(src + j, y, k);
-    }
-    return (n/QK_K*sizeof(block_iq2_xs));
-}
-
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -7730,3 +7917,666 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
     *s = 0.125f * sumf;
 #endif
 }
+
+// ================================ IQ2 quantization =============================================
+
+typedef struct {
+    uint64_t * grid;
+    int      * map;
+    uint16_t * neighbours;
+} iq2_entry_t;
+
+static iq2_entry_t iq2_data[2] = {
+    {NULL, NULL, NULL},
+    {NULL, NULL, NULL},
+};
+
+static inline int iq2_data_index(int grid_size) {
+    GGML_ASSERT(grid_size == 256 || grid_size == 512);
+    return grid_size == 256 ? 0 : 1;
+}
+
+static int iq2_compare_func(const void * left, const void * right) {
+    const int * l = (const int *)left;
+    const int * r = (const int *)right;
+    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
+static void q2xs_init_impl(int grid_size) {
+    const int gindex = iq2_data_index(grid_size);
+    if (iq2_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_256[256] = {
+            0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,
+          100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,
+         1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,
+         1312,  1350,  1385,  1408,  1425,  1545,  1552,  1600,  1668,  1700,  2048,  2053,  2056,  2068,  2088,  2113,
+         2116,  2128,  2130,  2184,  2308,  2368,  2562,  2580,  4097,  4100,  4112,  4129,  4160,  4192,  4228,  4240,
+         4245,  4352,  4360,  4384,  4432,  4442,  4480,  4644,  4677,  5120,  5128,  5152,  5157,  5193,  5248,  5400,
+         5474,  5632,  5654,  6145,  6148,  6160,  6208,  6273,  6400,  6405,  6560,  6737,  8192,  8194,  8202,  8260,
+         8289,  8320,  8322,  8489,  8520,  8704,  8706,  9217,  9220,  9232,  9280,  9302,  9472,  9537,  9572,  9872,
+        10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
+        16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
+        17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
+        20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
+        22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
+        25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
+        33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
+        37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
+    };
+    static const uint16_t kgrid_512[512] = {
+            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
+           73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,
+          260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,
+          352,   360,   385,   388,   400,   512,   514,   517,   520,   529,   532,   544,   577,   580,   592,   597,
+          640,   650,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1088,  1090,  1093,  1096,
+         1105,  1108,  1110,  1120,  1153,  1156,  1168,  1280,  1282,  1285,  1288,  1297,  1300,  1312,  1345,  1348,
+         1360,  1377,  1408,  1537,  1540,  1552,  1574,  1600,  1602,  1668,  2048,  2050,  2053,  2056,  2058,  2065,
+         2068,  2080,  2085,  2113,  2116,  2128,  2136,  2176,  2208,  2218,  2305,  2308,  2320,  2368,  2433,  2441,
+         2560,  2592,  2600,  2710,  2720,  4097,  4100,  4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4160,
+         4162,  4165,  4168,  4177,  4180,  4192,  4202,  4225,  4228,  4240,  4352,  4354,  4357,  4360,  4369,  4372,
+         4384,  4417,  4420,  4432,  4480,  4500,  4502,  4609,  4612,  4614,  4624,  4672,  4704,  5120,  5122,  5125,
+         5128,  5137,  5140,  5152,  5185,  5188,  5193,  5200,  5220,  5248,  5377,  5380,  5392,  5440,  5632,  5652,
+         5705,  6145,  6148,  6160,  6162,  6208,  6228,  6278,  6400,  6405,  6502,  6737,  6825,  8192,  8194,  8197,
+         8200,  8202,  8209,  8212,  8224,  8257,  8260,  8272,  8320,  8352,  8449,  8452,  8464,  8512,  8520,  8549,
+         8704,  8738,  8832,  8872,  9217,  9220,  9232,  9257,  9280,  9472,  9537,  9554,  9625,  9729,  9754,  9894,
+        10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
+        16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
+        16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
+        16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
+        17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
+        18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
+        20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
+        21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
+        22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
+        24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
+        32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
+        33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
+        33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
+        35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
+        37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
+        40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
+        42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
+    };
+    const int kmap_size = 43692;
+    const int nwant = 2;
+    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+    uint64_t * kgrid_q2xs;
+    int      * kmap_q2xs;
+    uint16_t * kneighbors_q2xs;
+
+    printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+    uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 8; ++i) {
+            int l = (kgrid[k] >> 2*i) & 0x3;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q2xs = the_grid;
+    iq2_data[gindex].grid = the_grid;
+    kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
+    iq2_data[gindex].map = kmap_q2xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
+    uint64_t aux64;
+    uint8_t * aux8 = (uint8_t *)&aux64;
+    for (int i = 0; i < grid_size; ++i) {
+        aux64 = kgrid_q2xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<8; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 2*k);
+        }
+        kmap_q2xs[index] = i;
+    }
+    int8_t pos[8];
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        ++num_not_in_map;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+        int n = 0; int d2 = dist2[0];
+        int nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            ++n;
+        }
+        num_neighbors += n;
+    }
+    printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+    kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    iq2_data[gindex].neighbours = kneighbors_q2xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+        kmap_q2xs[i] = -(counter + 1);
+        int d2 = dist2[0];
+        uint16_t * start = &kneighbors_q2xs[counter++];
+        int n = 0, nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            kneighbors_q2xs[counter++] = dist2[2*j+1];
+            ++n;
+        }
+        *start = n;
+    }
+    free(dist2);
+}
+
+void ggml_init_iq2_quantization(enum ggml_type type) {
+    if (type == GGML_TYPE_IQ2_XXS) {
+        q2xs_init_impl(256);
+    }
+    else if (type == GGML_TYPE_IQ2_XS) {
+        q2xs_init_impl(512);
+    }
+    else {
+        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
+    }
+}
+
+static void q2xs_deinit_impl(int grid_size) {
+    GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
+    const int gindex = iq2_data_index(grid_size);
+    if (iq2_data[gindex].grid) {
+        free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;
+        free(iq2_data[gindex].map);        iq2_data[gindex].map  = NULL;
+        free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
+    }
+}
+
+void ggml_deinit_iq2_quantization(enum ggml_type type) {
+    if (type == GGML_TYPE_IQ2_XXS) {
+        q2xs_deinit_impl(256);
+    }
+    else if (type == GGML_TYPE_IQ2_XS) {
+        q2xs_deinit_impl(512);
+    }
+    else {
+        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
+    }
+}
+
+static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_d2 = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = pg[i];
+            float diff = scale*q - xval[i];
+            d2 += weight[i]*diff*diff;
+        }
+        if (d2 < best_d2) {
+            best_d2 = d2; grid_index = neighbours[j];
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
+static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(256);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights);
+    GGML_ASSERT(kgrid_q2xs);
+    GGML_ASSERT(kmap_q2xs);
+    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int nbl = n/256;
+
+    block_iq2_xxs * y = vy;
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float  waux[32];
+    bool   is_on_grid[4];
+    bool   is_on_grid_aux[4];
+    uint8_t block_signs[4];
+    uint32_t q2[2*(QK_K/32)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            const float * qw = quant_weights + QK_K*ibl + 32*ib;
+            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 4; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 4; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
+                    for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 4; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                q2[2*ib+0] |= (grid_index << 8*k);
+                q2[2*ib+1] |= (block_signs[k] << 7*k);
+            }
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            q2[2*ib+1] |= ((uint32_t)l << 28);
+            const float * xb = xbl + 32*ib;
+            const float * qw = quant_weights + QK_K*ibl + 32*ib;
+            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
+            const float db = d * (1 + 2*l);
+            uint32_t u = 0;
+            for (int k = 0; k < 4; ++k) {
+                const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
+                const float * xk = xb + 8*k;
+                const float * wk = weight + 8*k;
+                const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+                float best_mse = 0; int best_index = aux8[k];
+                for (int j = 0; j < 8; ++j) {
+                    float diff = db * grid[j] * signs[j] - xk[j];
+                    best_mse += wk[j] * diff * diff;
+                }
+                for (int idx = 0; idx < 256; ++idx) {
+                    grid = (const uint8_t *)(kgrid_q2xs + idx);
+                    float mse = 0;
+                    for (int j = 0; j < 8; ++j) {
+                        float diff = db * grid[j] * signs[j] - xk[j];
+                        mse += wk[j] * diff * diff;
+                    }
+                    if (mse < best_mse) {
+                        best_mse = mse; best_index = idx;
+                    }
+                }
+                u |= (best_index << 8*k);
+                grid = (const uint8_t *)(kgrid_q2xs + best_index);
+                //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+                for (int j = 0; j < 8; ++j) {
+                    float q = db * grid[j] * signs[j];
+                    sumqx += wk[j] * q * xk[j];
+                    sumq2 += wk[j] * q * q;
+                }
+            }
+            q2[2*ib] = u;
+            if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+    }
+}
+
+static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(512);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights);
+    GGML_ASSERT(kmap_q2xs);
+    GGML_ASSERT(kgrid_q2xs);
+    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int nbl = n/256;
+
+    block_iq2_xs * y = vy;
+
+    float scales[QK_K/16];
+    float weight[16];
+    float xval[16];
+    int8_t L[16];
+    int8_t Laux[16];
+    float  waux[16];
+    bool   is_on_grid[2];
+    bool   is_on_grid_aux[2];
+    uint8_t block_signs[2];
+    uint16_t q2[2*(QK_K/16)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+        memset(y[ibl].scales, 0, QK_K/32);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            const float * xb = xbl + 16*ib;
+            const float * qw = quant_weights + QK_K*ibl + 16*ib;
+            for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 2; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 16);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            is_on_grid[0] = is_on_grid[1] = true;
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 2; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 2; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                        L[8*k + i] = l;
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                scale = -scale;
+                for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 2; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                q2[2*ib+k] = grid_index | (block_signs[k] << 9);
+            }
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+            else y[ibl].scales[ib/2] |= (l << 4);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+
+    }
+}
+
+size_t quantize_iq2_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq2_xxs);
+    }
+    return nrow * nblock * sizeof(block_iq2_xxs);
+}
+
+size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq2_xs);
+    }
+    return nrow * nblock * sizeof(block_iq2_xs);
+}
+
index df5e7ae807f5ffdd39de5ac676633a6071913479..e5d1102304ba5b2f4c1b750cd5ce2e7e42877ade 100644 (file)
@@ -196,8 +196,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
 void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
 void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
-void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
-void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs  * restrict y, int k);
 
 void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
 void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@@ -212,8 +210,6 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
-void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
-void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@@ -246,3 +242,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx,
 void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+//
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+//
+size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+
diff --git a/ggml.c b/ggml.c
index bcfb6652c10326f5be23f0b59cb3ce8c80d72168..52467475a1f22d909cdb688e96a438b0b27bc20c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -585,8 +585,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_size                = sizeof(block_iq2_xxs),
         .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xxs,
-        .from_float               = quantize_row_iq2_xxs,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
         .vec_dot                  = ggml_vec_dot_iq2_xxs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
@@ -596,8 +596,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_size                = sizeof(block_iq2_xs),
         .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xs,
-        .from_float               = quantize_row_iq2_xs,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_iq2_xs_reference,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
         .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
@@ -18665,8 +18665,11 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK8_0*sizeof(block_q8_0));
 }
 
-size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
+        int nrows, int n_per_row, int64_t * hist, const float * imatrix) {
+    (void)imatrix;
     size_t result = 0;
+    int n = nrows * n_per_row;
     switch (type) {
         case GGML_TYPE_Q4_0:
             {
@@ -18701,8 +18704,11 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
         case GGML_TYPE_Q2_K:
             {
                 GGML_ASSERT(start % QK_K == 0);
-                block_q2_K * block = (block_q2_K*)dst + start / QK_K;
-                result = ggml_quantize_q2_K(src + start, block, n, n, hist);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_q2_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
             } break;
         case GGML_TYPE_Q3_K:
             {
@@ -18731,14 +18737,22 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
         case GGML_TYPE_IQ2_XXS:
             {
                 GGML_ASSERT(start % QK_K == 0);
-                block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
-                result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
+                GGML_ASSERT(start % n_per_row == 0);
+                GGML_ASSERT(imatrix);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq2_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
             } break;
         case GGML_TYPE_IQ2_XS:
             {
                 GGML_ASSERT(start % QK_K == 0);
-                block_iq2_xs * block = (block_iq2_xs*)dst + start / QK_K;
-                result = ggml_quantize_iq2_xs(src + start, block, n, n, hist);
+                GGML_ASSERT(start % n_per_row == 0);
+                GGML_ASSERT(imatrix);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
             } break;
         case GGML_TYPE_F16:
             {
diff --git a/ggml.h b/ggml.h
index b18ba78120ca6d9aaec341529d1e351085326fcc..1187074f7f17420d57e8def001588770974b7af5 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -2067,10 +2067,13 @@ extern "C" {
     GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
-    GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
-    GGML_API size_t ggml_quantize_iq2_xs (const float * src, void * dst, int n, int k, int64_t * hist);
 
-    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
+    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst,
+            int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+
+    // These are needed for IQ2_XS and IQ2_XXS quantizations
+    GGML_API void ggml_init_iq2_quantization(enum ggml_type type);
+    GGML_API void ggml_deinit_iq2_quantization(enum ggml_type type);
 
     //
     // Importance matrix
index 8e20e72a232145a0baaa6ac196a17e62397c59d7..107b051147d3fcbacddf499917ceec1d2896b80a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -8429,9 +8429,23 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
         if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
             new_type = GGML_TYPE_Q8_0;
         }
+        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
+            new_type = GGML_TYPE_Q5_K;
+        }
         else if (new_type != GGML_TYPE_Q8_0) {
             new_type = GGML_TYPE_Q6_K;
         }
+    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
+        if (name.find("attn_v.weight") != std::string::npos) {
+            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
+            else new_type = GGML_TYPE_Q2_K;
+            ++qs.i_attention_wv;
+        }
+        else if (name.find("ffn_down") != std::string::npos) {
+            if (qs.i_feed_forward_w2 < qs.n_feed_forward_w2/8) new_type = GGML_TYPE_Q2_K;
+            ++qs.i_feed_forward_w2;
+        }
+        else if (name == "token_embd.weight") new_type = GGML_TYPE_Q2_K;
     } else if (name.find("attn_v.weight") != std::string::npos) {
         if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
@@ -8601,6 +8615,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     if (params->only_copy) {
         ftype = model.ftype;
     }
+    const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
+    if (params->imatrix) {
+        imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
+        if (imatrix_data) {
+            printf("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
+        }
+    }
 
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     struct gguf_context * ctx_out = gguf_init_empty();
@@ -8658,6 +8679,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     // placeholder for the meta data
     ::zeros(fout, meta_size);
 
+    std::set<ggml_type> used_iq2;
+
     for (int i = 0; i < ml.n_tensors; ++i) {
         struct ggml_tensor * tensor = ml.get_tensor_meta(i);
 
@@ -8710,6 +8733,35 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         } else {
             const size_t nelements = ggml_nelements(tensor);
 
+            if ((new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_XS) && used_iq2.find(new_type) == used_iq2.end()) {
+                ggml_init_iq2_quantization(new_type);
+                used_iq2.insert(new_type);
+            }
+
+            const float * imatrix = nullptr;
+            if (imatrix_data) {
+                auto it = imatrix_data->find(tensor->name);
+                if (it == imatrix_data->end()) {
+                    printf("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
+                } else {
+                    if (it->second.size() == (size_t)tensor->ne[0]) {
+                        imatrix = it->second.data();
+                    } else {
+                        printf("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
+                                int(it->second.size()), int(tensor->ne[0]), tensor->name);
+                    }
+                }
+            }
+            if ((new_type == GGML_TYPE_IQ2_XXS ||
+                 new_type == GGML_TYPE_IQ2_XS  ||
+                (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
+                fprintf(stderr, "\n\n============================================================\n");
+                fprintf(stderr, "Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
+                fprintf(stderr, "The result will be garbage, so bailing out\n");
+                fprintf(stderr, "============================================================\n\n");
+                throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
+            }
+
             float * f32_data;
 
             if (tensor->type == GGML_TYPE_F32) {
@@ -8730,21 +8782,28 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             new_data = work.data();
             std::array<int64_t, 1 << 4> hist_cur = {};
 
-            static const int chunk_size = 32 * 512;
+            const int n_per_row = tensor->ne[0];
+            const int nrows = nelements / n_per_row;
+
+            static const int min_chunk_size = 32 * 512;
+            const int chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
+
             const int nchunk = (nelements + chunk_size - 1)/chunk_size;
             const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1;
             if (nthread_use < 2) {
-                new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nelements, hist_cur.data());
+                new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, hist_cur.data(), imatrix);
             } else {
-                size_t counter = 0;
+                int counter = 0;
                 new_size = 0;
-                auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements]() {
+                auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, chunk_size,
+                     nrows, n_per_row, imatrix]() {
                     std::array<int64_t, 1 << 4> local_hist = {};
+                    const int nrows_per_chunk = chunk_size / n_per_row;
                     size_t local_size = 0;
                     while (true) {
                         std::unique_lock<std::mutex> lock(mutex);
-                        size_t first = counter; counter += chunk_size;
-                        if (first >= nelements) {
+                        int first_row = counter; counter += nrows_per_chunk;
+                        if (first_row >= nrows) {
                             if (local_size > 0) {
                                 for (int j=0; j<int(local_hist.size()); ++j) {
                                     hist_cur[j] += local_hist[j];
@@ -8754,8 +8813,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                             break;
                         }
                         lock.unlock();
-                        size_t last = std::min(nelements, first + chunk_size);
-                        local_size += ggml_quantize_chunk(new_type, f32_data, new_data, first, last - first, local_hist.data());
+                        const int this_nrow = std::min(nrows - first_row, nrows_per_chunk);
+                        local_size += ggml_quantize_chunk(new_type, f32_data, new_data,
+                                first_row * n_per_row, this_nrow, n_per_row, local_hist.data(), imatrix);
                     }
                 };
                 for (int it = 0; it < nthread_use - 1; ++it) {
@@ -8766,7 +8826,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 workers.clear();
             }
 
-            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
+            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
             int64_t tot_count = 0;
             for (size_t i = 0; i < hist_cur.size(); i++) {
                 hist_all[i] += hist_cur[i];
@@ -8774,6 +8834,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             }
 
             if (tot_count > 0) {
+                LLAMA_LOG_INFO(" | hist: ");
                 for (size_t i = 0; i < hist_cur.size(); i++) {
                     LLAMA_LOG_INFO("%5.3f ", hist_cur[i] / float(nelements));
                 }
@@ -8802,6 +8863,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     fout.close();
 
+    for (auto type : used_iq2) {
+        ggml_deinit_iq2_quantization(type);
+    }
+
     gguf_free(ctx_out);
 
     LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
@@ -9166,6 +9231,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
         /*.quantize_output_tensor      =*/ true,
         /*.only_copy                   =*/ false,
         /*.pure                        =*/ false,
+        /*.imatrix                     =*/ nullptr,
     };
 
     return result;
diff --git a/llama.h b/llama.h
index 01d6fafaa4b0d2fecc4a60d673ddbde68f738ad5..79c8335b66bdfa2d77d91e03c3ae2f19d59632ad 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -249,6 +249,7 @@ extern "C" {
         bool quantize_output_tensor; // quantize output.weight
         bool only_copy;              // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
         bool pure;                   // disable k-quant mixtures and quantize all tensors to the same type
+        void * imatrix;              // pointer to importance matrix data
     } llama_model_quantize_params;
 
     // grammar types
index d9b8b106a6033a9efcde0fb9eb13788a88c1d811..22a7856d46f417f91cee1d8602a475a7a1650c39 100644 (file)
@@ -56,7 +56,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
         GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
         std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
         int64_t hist[16];
-        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size, hist);
+        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], hist, nullptr);
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.