]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
2-bit quantizations (llama/4897)
authorKawrakow <redacted>
Sun, 14 Jan 2024 07:45:56 +0000 (09:45 +0200)
committerGeorgi Gerganov <redacted>
Sun, 14 Jan 2024 08:54:09 +0000 (10:54 +0200)
* imatrix: load

* imatrix: WIP

* imatrix: Add Q2_K quantization

* imatrix: also guard against Q2_K_S quantization without importance matrix

* imatrix: guard even more against low-bit quantization misuse

---------

Co-authored-by: Iwan Kawrakow <redacted>
ggml-quants.c
ggml-quants.h
ggml.c
ggml.h

index 601d155d73696518854c04c316468090d3dd5b6e..9290d54cfba7a4d792f68663e26fbd2b5f46ba1f 100644 (file)
@@ -5,6 +5,8 @@
 #include <string.h>
 #include <assert.h>
 #include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
 
 #ifdef __ARM_NEON
 
@@ -1639,6 +1641,241 @@ size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n
     return (n/QK_K*sizeof(block_q2_K));
 }
 
+static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights ? weights[0] : x[0]*x[0];
+    float sum_x = sum_w * x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights ? weights[i] : x[i]*x[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) {
+        min = 0;
+    }
+    if (max <= min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff*diff;
+        float w = weights ? weights[i] : x[i]*x[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights ? weights[i] : x[i]*x[i];
+            sum_l  += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff*diff;
+                float w = weights ? weights[i] : x[i]*x[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
+    float max = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    if (!max) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = nmax / max;
+    for (int i = 0; i < n; ++i) {
+        L[i] = nearest_int(iscale * x[i]);
+    }
+    float scale = 1/iscale;
+    float best_mse = 0;
+    for (int i = 0; i < n; ++i) {
+        float diff = x[i] - scale*L[i];
+        float w = quant_weights[i];
+        best_mse += w*diff*diff;
+    }
+    for (int is = -4; is <= 4; ++is) {
+        if (is == 0) continue;
+        float iscale_is = (0.1f*is + nmax)/max;
+        float scale_is = 1/iscale_is;
+        float mse = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale_is*x[i]);
+            l = MIN(nmax, l);
+            float diff = x[i] - scale_is*l;
+            float w = quant_weights[i];
+            mse += w*diff*diff;
+        }
+        if (mse < best_mse) {
+            best_mse = mse;
+            iscale = iscale_is;
+        }
+    }
+    float sumlx = 0;
+    float suml2 = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MIN(nmax, l);
+        L[i] = l;
+        float w = quant_weights[i];
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    for (int itry = 0; itry < 5; ++itry) {
+        int n_changed = 0;
+        for (int i = 0; i < n; ++i) {
+            float w = quant_weights[i];
+            float slx = sumlx - w*x[i]*L[i];
+            float sl2 = suml2 - w*L[i]*L[i];
+            if (slx > 0 && sl2 > 0) {
+                int new_l = nearest_int(x[i] * sl2 / slx);
+                new_l = MIN(nmax, new_l);
+                if (new_l != L[i]) {
+                    slx += w*x[i]*new_l;
+                    sl2 += w*new_l*new_l;
+                    if (slx*slx*suml2 > sumlx*sumlx*sl2) {
+                        L[i] = new_l; sumlx = slx; suml2 = sl2;
+                        ++n_changed;
+                    }
+                }
+            }
+        }
+        if (!n_changed) {
+            break;
+        }
+    }
+    return sumlx / suml2;
+}
+
+static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
+    GGML_ASSERT(quant_weights);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+    const bool requantize = true;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+    float sw[QK_K/16];
+    float weight[QK_K/16];
+    uint8_t Ls[QK_K/16], Lm[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+        memset(sw, 0, QK_K/16*sizeof(float));
+        float sumx2 = 0;
+        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+        float sigma2 = sumx2/QK_K;
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float * restrict qw = quant_weights + QK_K * i + 16*j;
+            for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+            for (int l = 0; l < 16; ++l) sw[j] += weight[l];
+            scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+        }
+
+        float dm  = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+        float mm  = make_qp_quants(QK_K/16, 15, mins,   Lm, sw);
+        y[i].d    = GGML_FP32_TO_FP16(dm);
+        y[i].dmin = GGML_FP32_TO_FP16(mm);
+        dm        = GGML_FP16_TO_FP32(y[i].d);
+        mm        = GGML_FP16_TO_FP32(y[i].dmin);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            y[i].scales[j] = Ls[j] | (Lm[j] << 4);
+        }
+
+        if (requantize) {
+            for (int j = 0; j < QK_K/16; ++j) {
+                const float d = dm * (y[i].scales[j] & 0xF);
+                if (!d) continue;
+                const float m = mm * (y[i].scales[j] >> 4);
+                for (int ii = 0; ii < 16; ++ii) {
+                    int l = nearest_int((x[16*j + ii] + m)/d);
+                    l = MAX(0, MIN(3, l));
+                    L[16*j + ii] = l;
+                }
+            }
+        }
+
+#if QK_K == 256
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+#else
+        for (int l = 0; l < 16; ++l) {
+            y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6);
+        }
+#endif
+
+        x += QK_K;
+
+    }
+}
+
+size_t quantize_q2_K(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    int row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
+    if (!quant_weights) {
+        quantize_row_q2_K_reference(src, dst, nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int row = 0; row < nrow; ++row) {
+            quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
 //========================= 3-bit (de)-quantization
 
 void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
@@ -2584,14 +2821,6 @@ static const uint8_t ksigns_iq2xs[128] = {
 
 static const uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
 
-void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k) {
-    (void)x;
-    (void)y;
-    (void)k;
-    assert(k % QK_K == 0);
-    //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
-}
-
 void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -2618,33 +2847,8 @@ void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y
     }
 }
 
-void quantize_row_iq2_xxs(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_iq2_xxs * restrict y = vy;
-    quantize_row_iq2_xxs_reference(x, y, k);
-}
-
-size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
-
-    for (int j = 0; j < n; j += k) {
-        block_iq2_xxs * restrict y = (block_iq2_xxs *)dst + j/QK_K;
-        quantize_row_iq2_xxs_reference(src + j, y, k);
-    }
-    return (n/QK_K*sizeof(block_iq2_xxs));
-}
-
 // ====================== 2.3125 bpw (de)-quantization
 
-void quantize_row_iq2_xs_reference(const float * restrict x, block_iq2_xs * restrict y, int k) {
-    (void)x;
-    (void)y;
-    (void)k;
-    assert(k % QK_K == 0);
-    //fprintf(stderr, "=========================== %s: not implemented\n", __func__);
-}
-
 void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;
@@ -2670,23 +2874,6 @@ void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y,
     }
 }
 
-void quantize_row_iq2_xs(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK_K == 0);
-    block_iq2_xs * restrict y = vy;
-    quantize_row_iq2_xs_reference(x, y, k);
-}
-
-size_t ggml_quantize_iq2_xs(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK_K == 0);
-    (void)hist; // TODO: collect histograms
-
-    for (int j = 0; j < n; j += k) {
-        block_iq2_xs * restrict y = (block_iq2_xs *)dst + j/QK_K;
-        quantize_row_iq2_xs_reference(src + j, y, k);
-    }
-    return (n/QK_K*sizeof(block_iq2_xs));
-}
-
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@@ -7730,3 +7917,666 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
     *s = 0.125f * sumf;
 #endif
 }
+
+// ================================ IQ2 quantization =============================================
+
+typedef struct {
+    uint64_t * grid;
+    int      * map;
+    uint16_t * neighbours;
+} iq2_entry_t;
+
+static iq2_entry_t iq2_data[2] = {
+    {NULL, NULL, NULL},
+    {NULL, NULL, NULL},
+};
+
+static inline int iq2_data_index(int grid_size) {
+    GGML_ASSERT(grid_size == 256 || grid_size == 512);
+    return grid_size == 256 ? 0 : 1;
+}
+
+static int iq2_compare_func(const void * left, const void * right) {
+    const int * l = (const int *)left;
+    const int * r = (const int *)right;
+    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
+static void q2xs_init_impl(int grid_size) {
+    const int gindex = iq2_data_index(grid_size);
+    if (iq2_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_256[256] = {
+            0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,
+          100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,
+         1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,
+         1312,  1350,  1385,  1408,  1425,  1545,  1552,  1600,  1668,  1700,  2048,  2053,  2056,  2068,  2088,  2113,
+         2116,  2128,  2130,  2184,  2308,  2368,  2562,  2580,  4097,  4100,  4112,  4129,  4160,  4192,  4228,  4240,
+         4245,  4352,  4360,  4384,  4432,  4442,  4480,  4644,  4677,  5120,  5128,  5152,  5157,  5193,  5248,  5400,
+         5474,  5632,  5654,  6145,  6148,  6160,  6208,  6273,  6400,  6405,  6560,  6737,  8192,  8194,  8202,  8260,
+         8289,  8320,  8322,  8489,  8520,  8704,  8706,  9217,  9220,  9232,  9280,  9302,  9472,  9537,  9572,  9872,
+        10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
+        16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
+        17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
+        20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
+        22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
+        25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
+        33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
+        37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
+    };
+    static const uint16_t kgrid_512[512] = {
+            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
+           73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,
+          260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,
+          352,   360,   385,   388,   400,   512,   514,   517,   520,   529,   532,   544,   577,   580,   592,   597,
+          640,   650,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1088,  1090,  1093,  1096,
+         1105,  1108,  1110,  1120,  1153,  1156,  1168,  1280,  1282,  1285,  1288,  1297,  1300,  1312,  1345,  1348,
+         1360,  1377,  1408,  1537,  1540,  1552,  1574,  1600,  1602,  1668,  2048,  2050,  2053,  2056,  2058,  2065,
+         2068,  2080,  2085,  2113,  2116,  2128,  2136,  2176,  2208,  2218,  2305,  2308,  2320,  2368,  2433,  2441,
+         2560,  2592,  2600,  2710,  2720,  4097,  4100,  4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4160,
+         4162,  4165,  4168,  4177,  4180,  4192,  4202,  4225,  4228,  4240,  4352,  4354,  4357,  4360,  4369,  4372,
+         4384,  4417,  4420,  4432,  4480,  4500,  4502,  4609,  4612,  4614,  4624,  4672,  4704,  5120,  5122,  5125,
+         5128,  5137,  5140,  5152,  5185,  5188,  5193,  5200,  5220,  5248,  5377,  5380,  5392,  5440,  5632,  5652,
+         5705,  6145,  6148,  6160,  6162,  6208,  6228,  6278,  6400,  6405,  6502,  6737,  6825,  8192,  8194,  8197,
+         8200,  8202,  8209,  8212,  8224,  8257,  8260,  8272,  8320,  8352,  8449,  8452,  8464,  8512,  8520,  8549,
+         8704,  8738,  8832,  8872,  9217,  9220,  9232,  9257,  9280,  9472,  9537,  9554,  9625,  9729,  9754,  9894,
+        10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
+        16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
+        16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
+        16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
+        17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
+        18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
+        20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
+        21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
+        22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
+        24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
+        32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
+        33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
+        33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
+        35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
+        37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
+        40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
+        42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
+    };
+    const int kmap_size = 43692;
+    const int nwant = 2;
+    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+    uint64_t * kgrid_q2xs;
+    int      * kmap_q2xs;
+    uint16_t * kneighbors_q2xs;
+
+    printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+    uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 8; ++i) {
+            int l = (kgrid[k] >> 2*i) & 0x3;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q2xs = the_grid;
+    iq2_data[gindex].grid = the_grid;
+    kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
+    iq2_data[gindex].map = kmap_q2xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
+    uint64_t aux64;
+    uint8_t * aux8 = (uint8_t *)&aux64;
+    for (int i = 0; i < grid_size; ++i) {
+        aux64 = kgrid_q2xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<8; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 2*k);
+        }
+        kmap_q2xs[index] = i;
+    }
+    int8_t pos[8];
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        ++num_not_in_map;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+        int n = 0; int d2 = dist2[0];
+        int nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            ++n;
+        }
+        num_neighbors += n;
+    }
+    printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+    kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    iq2_data[gindex].neighbours = kneighbors_q2xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+        kmap_q2xs[i] = -(counter + 1);
+        int d2 = dist2[0];
+        uint16_t * start = &kneighbors_q2xs[counter++];
+        int n = 0, nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            kneighbors_q2xs[counter++] = dist2[2*j+1];
+            ++n;
+        }
+        *start = n;
+    }
+    free(dist2);
+}
+
+void ggml_init_iq2_quantization(enum ggml_type type) {
+    if (type == GGML_TYPE_IQ2_XXS) {
+        q2xs_init_impl(256);
+    }
+    else if (type == GGML_TYPE_IQ2_XS) {
+        q2xs_init_impl(512);
+    }
+    else {
+        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
+    }
+}
+
+static void q2xs_deinit_impl(int grid_size) {
+    GGML_ASSERT(grid_size == 256 || grid_size == 512 || grid_size == 1024);
+    const int gindex = iq2_data_index(grid_size);
+    if (iq2_data[gindex].grid) {
+        free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;
+        free(iq2_data[gindex].map);        iq2_data[gindex].map  = NULL;
+        free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
+    }
+}
+
+void ggml_deinit_iq2_quantization(enum ggml_type type) {
+    if (type == GGML_TYPE_IQ2_XXS) {
+        q2xs_deinit_impl(256);
+    }
+    else if (type == GGML_TYPE_IQ2_XS) {
+        q2xs_deinit_impl(512);
+    }
+    else {
+        fprintf(stderr, "======================== Why are you calling %s with type %d?\n", __func__, (int)type);
+    }
+}
+
+static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_d2 = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = pg[i];
+            float diff = scale*q - xval[i];
+            d2 += weight[i]*diff*diff;
+        }
+        if (d2 < best_d2) {
+            best_d2 = d2; grid_index = neighbours[j];
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
+static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(256);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights);
+    GGML_ASSERT(kgrid_q2xs);
+    GGML_ASSERT(kmap_q2xs);
+    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int nbl = n/256;
+
+    block_iq2_xxs * y = vy;
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float  waux[32];
+    bool   is_on_grid[4];
+    bool   is_on_grid_aux[4];
+    uint8_t block_signs[4];
+    uint32_t q2[2*(QK_K/32)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            const float * qw = quant_weights + QK_K*ibl + 32*ib;
+            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 4; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 4; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
+                    for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 4; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                q2[2*ib+0] |= (grid_index << 8*k);
+                q2[2*ib+1] |= (block_signs[k] << 7*k);
+            }
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            q2[2*ib+1] |= ((uint32_t)l << 28);
+            const float * xb = xbl + 32*ib;
+            const float * qw = quant_weights + QK_K*ibl + 32*ib;
+            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
+            const float db = d * (1 + 2*l);
+            uint32_t u = 0;
+            for (int k = 0; k < 4; ++k) {
+                const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
+                const float * xk = xb + 8*k;
+                const float * wk = weight + 8*k;
+                const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+                float best_mse = 0; int best_index = aux8[k];
+                for (int j = 0; j < 8; ++j) {
+                    float diff = db * grid[j] * signs[j] - xk[j];
+                    best_mse += wk[j] * diff * diff;
+                }
+                for (int idx = 0; idx < 256; ++idx) {
+                    grid = (const uint8_t *)(kgrid_q2xs + idx);
+                    float mse = 0;
+                    for (int j = 0; j < 8; ++j) {
+                        float diff = db * grid[j] * signs[j] - xk[j];
+                        mse += wk[j] * diff * diff;
+                    }
+                    if (mse < best_mse) {
+                        best_mse = mse; best_index = idx;
+                    }
+                }
+                u |= (best_index << 8*k);
+                grid = (const uint8_t *)(kgrid_q2xs + best_index);
+                //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
+                for (int j = 0; j < 8; ++j) {
+                    float q = db * grid[j] * signs[j];
+                    sumqx += wk[j] * q * xk[j];
+                    sumq2 += wk[j] * q * q;
+                }
+            }
+            q2[2*ib] = u;
+            if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+    }
+}
+
+static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(512);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights);
+    GGML_ASSERT(kmap_q2xs);
+    GGML_ASSERT(kgrid_q2xs);
+    GGML_ASSERT(kneighbors_q2xs);
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int nbl = n/256;
+
+    block_iq2_xs * y = vy;
+
+    float scales[QK_K/16];
+    float weight[16];
+    float xval[16];
+    int8_t L[16];
+    int8_t Laux[16];
+    float  waux[16];
+    bool   is_on_grid[2];
+    bool   is_on_grid_aux[2];
+    uint8_t block_signs[2];
+    uint16_t q2[2*(QK_K/16)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+        memset(y[ibl].scales, 0, QK_K/32);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            const float * xb = xbl + 16*ib;
+            const float * qw = quant_weights + QK_K*ibl + 16*ib;
+            for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 2; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                memset(L, 0, 16);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            is_on_grid[0] = is_on_grid[1] = true;
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 2; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 2; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                        L[8*k + i] = l;
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                scale = -scale;
+                for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 2; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    GGML_ASSERT(false);
+                }
+                q2[2*ib+k] = grid_index | (block_signs[k] << 9);
+            }
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+            else y[ibl].scales[ib/2] |= (l << 4);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+
+    }
+}
+
+size_t quantize_iq2_xxs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq2_xxs);
+    }
+    return nrow * nblock * sizeof(block_iq2_xxs);
+}
+
+size_t quantize_iq2_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
+    (void)hist;
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int row = 0; row < nrow; ++row) {
+        quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq2_xs);
+    }
+    return nrow * nblock * sizeof(block_iq2_xs);
+}
+
index df5e7ae807f5ffdd39de5ac676633a6071913479..e5d1102304ba5b2f4c1b750cd5ce2e7e42877ade 100644 (file)
@@ -196,8 +196,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict
 void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
 void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
 void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
-void quantize_row_iq2_xxs_reference(const float * restrict x, block_iq2_xxs * restrict y, int k);
-void quantize_row_iq2_xs_reference (const float * restrict x, block_iq2_xs  * restrict y, int k);
 
 void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
 void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
@@ -212,8 +210,6 @@ void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
-void quantize_row_iq2_xxs(const float * restrict x, void * restrict y, int k);
-void quantize_row_iq2_xs (const float * restrict x, void * restrict y, int k);
 
 // Dequantization
 void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
@@ -246,3 +242,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx,
 void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+
+//
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+//
+size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+
diff --git a/ggml.c b/ggml.c
index bcfb6652c10326f5be23f0b59cb3ce8c80d72168..52467475a1f22d909cdb688e96a438b0b27bc20c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -585,8 +585,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_size                = sizeof(block_iq2_xxs),
         .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xxs,
-        .from_float               = quantize_row_iq2_xxs,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
         .vec_dot                  = ggml_vec_dot_iq2_xxs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
@@ -596,8 +596,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_size                = sizeof(block_iq2_xs),
         .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xs,
-        .from_float               = quantize_row_iq2_xs,
-        .from_float_reference     = (ggml_from_float_t) quantize_row_iq2_xs_reference,
+        .from_float               = NULL,
+        .from_float_reference     = NULL,
         .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
@@ -18665,8 +18665,11 @@ size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK8_0*sizeof(block_q8_0));
 }
 
-size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start,
+        int nrows, int n_per_row, int64_t * hist, const float * imatrix) {
+    (void)imatrix;
     size_t result = 0;
+    int n = nrows * n_per_row;
     switch (type) {
         case GGML_TYPE_Q4_0:
             {
@@ -18701,8 +18704,11 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
         case GGML_TYPE_Q2_K:
             {
                 GGML_ASSERT(start % QK_K == 0);
-                block_q2_K * block = (block_q2_K*)dst + start / QK_K;
-                result = ggml_quantize_q2_K(src + start, block, n, n, hist);
+                GGML_ASSERT(start % n_per_row == 0);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_q2_K(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
             } break;
         case GGML_TYPE_Q3_K:
             {
@@ -18731,14 +18737,22 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
         case GGML_TYPE_IQ2_XXS:
             {
                 GGML_ASSERT(start % QK_K == 0);
-                block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
-                result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
+                GGML_ASSERT(start % n_per_row == 0);
+                GGML_ASSERT(imatrix);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq2_xxs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
             } break;
         case GGML_TYPE_IQ2_XS:
             {
                 GGML_ASSERT(start % QK_K == 0);
-                block_iq2_xs * block = (block_iq2_xs*)dst + start / QK_K;
-                result = ggml_quantize_iq2_xs(src + start, block, n, n, hist);
+                GGML_ASSERT(start % n_per_row == 0);
+                GGML_ASSERT(imatrix);
+                size_t start_row = start / n_per_row;
+                size_t row_size = ggml_row_size(type, n_per_row);
+                result = quantize_iq2_xs(src + start, (char *)dst + start_row * row_size, nrows, n_per_row, hist, imatrix);
+                GGML_ASSERT(result == row_size * nrows);
             } break;
         case GGML_TYPE_F16:
             {
diff --git a/ggml.h b/ggml.h
index b18ba78120ca6d9aaec341529d1e351085326fcc..1187074f7f17420d57e8def001588770974b7af5 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -2067,10 +2067,13 @@ extern "C" {
     GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);
-    GGML_API size_t ggml_quantize_iq2_xxs(const float * src, void * dst, int n, int k, int64_t * hist);
-    GGML_API size_t ggml_quantize_iq2_xs (const float * src, void * dst, int n, int k, int64_t * hist);
 
-    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
+    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst,
+            int start, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
+
+    // These are needed for IQ2_XS and IQ2_XXS quantizations
+    GGML_API void ggml_init_iq2_quantization(enum ggml_type type);
+    GGML_API void ggml_deinit_iq2_quantization(enum ggml_type type);
 
     //
     // Importance matrix