]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
iq2_xxs: tune quantization (llama/5320)
authorKawrakow <redacted>
Mon, 5 Feb 2024 08:46:06 +0000 (10:46 +0200)
committerGeorgi Gerganov <redacted>
Sat, 10 Feb 2024 07:55:46 +0000 (09:55 +0200)
We get slightly better PPL, and we cut quantization time in
nearly half.

The trick is to 1st quantize without forcing points onto the E8-lattice.
We can then use a narrower search range around the block scale that we
got that way.

Co-authored-by: Iwan Kawrakow <redacted>
ggml-quants.c

index 8236385bce8e9e300e06ac0b4fe789c5abe4e540..014c0525abd1bb74cb17b01f3b19927af2225066 100644 (file)
@@ -9048,8 +9048,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
     int8_t L[32];
     int8_t Laux[32];
     float  waux[32];
-    bool   is_on_grid[4];
-    bool   is_on_grid_aux[4];
     uint8_t block_signs[4];
     uint32_t q2[2*(QK_K/32)];
 
@@ -9099,10 +9097,11 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
                 memset(L, 0, 32);
                 continue;
             }
+            float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
+            float eff_max = scale*kMaxQ;
             float best = 0;
-            float scale = max/(2*kMaxQ-1);
-            for (int is = -9; is <= 9; ++is) {
-                float id = (2*kMaxQ-1+is*0.1f)/max;
+            for (int is = -6; is <= 6; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/eff_max;
                 float this_scale = 1/id;
                 for (int k = 0; k < 4; ++k) {
                     for (int i = 0; i < 8; ++i) {
@@ -9112,9 +9111,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
                     uint16_t u = 0;
                     for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
                     int grid_index = kmap_q2xs[u];
-                    is_on_grid_aux[k] = true;
                     if (grid_index < 0) {
-                        is_on_grid_aux[k] = false;
                         const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
                         grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
                     }
@@ -9128,16 +9125,12 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
                 }
                 if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
                     scale = sumqx/sumq2; best = scale*sumqx;
-                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
-                    for (int k = 0; k <  4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                    memcpy(L, Laux, 32);
                 }
             }
-            int n_not_ongrid = 0;
-            for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
-            if (n_not_ongrid > 0 && scale > 0) {
+            if (scale > 0) {
                 float id = 1/scale;
                 for (int k = 0; k < 4; ++k) {
-                    if (is_on_grid[k]) continue;
                     uint16_t u = 0;
                     for (int i = 0; i < 8; ++i) {
                         int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
@@ -9193,49 +9186,10 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
         float d = max_scale/31;
         y[ibl].d = GGML_FP32_TO_FP16(d);
         float id = 1/d;
-        float sumqx = 0, sumq2 = 0;
         for (int ib = 0; ib < QK_K/32; ++ib) {
             int l = nearest_int(0.5f*(id*scales[ib]-1));
             l = MAX(0, MIN(15, l));
             q2[2*ib+1] |= ((uint32_t)l << 28);
-            const float * xb = xbl + 32*ib;
-            const float * qw = quant_weights + QK_K*ibl + 32*ib;
-            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
-            const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
-            const float db = d * (1 + 2*l);
-            uint32_t u = 0;
-            for (int k = 0; k < 4; ++k) {
-                const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
-                const float * xk = xb + 8*k;
-                const float * wk = weight + 8*k;
-                const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
-                float best_mse = 0; int best_index = aux8[k];
-                for (int j = 0; j < 8; ++j) {
-                    float diff = db * grid[j] * signs[j] - xk[j];
-                    best_mse += wk[j] * diff * diff;
-                }
-                for (int idx = 0; idx < 256; ++idx) {
-                    grid = (const uint8_t *)(kgrid_q2xs + idx);
-                    float mse = 0;
-                    for (int j = 0; j < 8; ++j) {
-                        float diff = db * grid[j] * signs[j] - xk[j];
-                        mse += wk[j] * diff * diff;
-                    }
-                    if (mse < best_mse) {
-                        best_mse = mse; best_index = idx;
-                    }
-                }
-                u |= (best_index << 8*k);
-                grid = (const uint8_t *)(kgrid_q2xs + best_index);
-                //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
-                for (int j = 0; j < 8; ++j) {
-                    float q = db * grid[j] * signs[j];
-                    sumqx += wk[j] * q * xk[j];
-                    sumq2 += wk[j] * q * q;
-                }
-            }
-            q2[2*ib] = u;
-            if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
         }
         memcpy(y[ibl].qs, q2, QK_K/4);
     }