]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Make IQ1_M work for QK_K = 64 (llama/6327)
authorKawrakow <redacted>
Wed, 27 Mar 2024 07:44:27 +0000 (08:44 +0100)
committerGeorgi Gerganov <redacted>
Wed, 27 Mar 2024 11:20:00 +0000 (13:20 +0200)
* iq1_m: make it work for QK_K = 64 (WIP)

* iq1_m: make it work for QK_K = 64 (scalar and AVX2)

* iq1_m: QK_K = 64 seems to work on Metal and ARM_NEON

---------

Co-authored-by: Iwan Kawrakow <redacted>
src/ggml-common.h
src/ggml-metal.metal
src/ggml-quants.c

index 517c9bb43b380ca15c7d9959f9cabd37c48589f3..b2d67d5db529ccf7c8df056c81703b69c14a7143 100644 (file)
@@ -377,13 +377,20 @@ typedef struct {
 } block_iq1_s;
 static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
 
-// 1.8125 bpw
+// 1.75 bpw
 typedef struct {
     uint8_t  qs[QK_K/8];      // grid index, low 8 bits
     uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
-    uint8_t  scales[QK_K/32]; // 4-bit block scales
+#if QK_K == 64
+    ggml_half d;
+#endif
+    uint8_t  scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
 } block_iq1_m;
+#if QK_K == 64
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32 + sizeof(ggml_half), "wrong iq1_m block size/padding");
+#else
 static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+#endif
 
 // Used by IQ1_M quants
 typedef union {
index e8083734ca4dfa069cd57c2e0e45e79a7c2c2c42..744b2a8b4ce42c3e5982aceb478617ad7e1258a5 100644 (file)
@@ -4497,7 +4497,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device const float * y4 = y + 32 * ix;
 
+#if QK_K != 64
     iq1m_scale_t scale;
+#endif
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
 
@@ -4519,7 +4521,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
         for (int row = 0; row < N_DST; row++) {
 
+#if QK_K != 64
             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif
 
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
             constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
@@ -4535,8 +4539,14 @@ void kernel_mul_mv_iq1_m_f32_impl(
             }
             const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
             const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+#if QK_K == 64
+            const float d = (float) *((device const half *)(sc - 1));
+            sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
+                              (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
+#else
             sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                              (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+#endif
 
             sc += nb*sizeof(block_iq1_m)/2;
             qs += nb*sizeof(block_iq1_m);
@@ -5277,13 +5287,21 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 &
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
     const int ib32 = il/2;
     il = il%2;
-    iq1m_scale_t scale;
     device const uint16_t * sc = (device const uint16_t *)xb->scales;
+#if QK_K == 64
+    const float d = xb->d;
+#else
+    iq1m_scale_t scale;
     scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
     const float d = scale.f16;
+#endif
     device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
     device const uint8_t * qh = xb->qh + 2*ib32 + il;
+#if QK_K == 64
+    const float dl  = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
+#else
     const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+#endif
     const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
     const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
     constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
index f717e616e6a2c3759308b8a97f3d6d7643b9c4e5..f2e6c4bd1a3216b21a2d071bc0c60ce22c9ab777 100644 (file)
@@ -3481,19 +3481,30 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in
     float delta[4];
     uint16_t idx[4];
 
+#if QK_K != 64
     iq1m_scale_t scale;
+#endif
 
     for (int i = 0; i < nb; i++) {
 
         const uint16_t * sc = (const uint16_t *)x[i].scales;
+#if QK_K == 64
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+#else
         scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
         const float d = GGML_FP16_TO_FP32(scale.f16);
+#endif
         const uint8_t * qs = x[i].qs;
         const uint8_t * qh = x[i].qh;
 
         for (int ib = 0; ib < QK_K/32; ++ib) {
+#if QK_K == 64
+            const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1);
+            const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1);
+#else
             const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
             const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
+#endif
             idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
             idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
             idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
@@ -9756,11 +9767,17 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
 
     const int nb = n / QK_K;
 
+#if QK_K != 64
     iq1m_scale_t scale;
+#endif
 
 #if defined __ARM_NEON
 
+#if QK_K == 64
+    const int32x4_t mask  = vdupq_n_s32(0xf);
+#else
     const int32x4_t mask  = vdupq_n_s32(0x7);
+#endif
     const int32x4_t mone  = vdupq_n_s32(1);
     const int32x4_t mzero = vdupq_n_s32(0);
 
@@ -9784,7 +9801,9 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
         const uint8_t  * qh = x[i].qh;
         const uint16_t * sc = (const uint16_t *)x[i].scales;
 
+#if QK_K != 64
         scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif
 
         int32x4_t sumi1 = mzero;
         int32x4_t sumi2 = mzero;
@@ -9813,7 +9832,11 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
             const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
             const int32x4_t p34 = vpaddq_s32(p3, p4);
 
+#if QK_K == 64
+            int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12);
+#else
             int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
+#endif
             scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
 
             sumi1 = vmlaq_s32(sumi1, scales_4, p12);
@@ -9823,14 +9846,22 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
 
         }
 
+#if QK_K == 64
+        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+#else
         sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+#endif
     }
 
     *s = sumf;
 
 #elif defined __AVX2__
 
+#if QK_K == 64
+    const __m256i mask = _mm256_set1_epi16(0xf);
+#else
     const __m256i mask = _mm256_set1_epi16(0x7);
+#endif
     const __m256i mone = _mm256_set1_epi16(1);
 
     __m256 accum1 = _mm256_setzero_ps();
@@ -9842,7 +9873,9 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
         const uint8_t  * qh = x[i].qh;
         const uint16_t * sc = (const uint16_t *)x[i].scales;
 
+#if QK_K != 64
         scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif
 
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
@@ -9872,8 +9905,13 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
 
             const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
             const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
+#if QK_K == 64
+            __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >>  4), _mm_set1_epi16(sc[0] >> 0));
+            __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8));
+#else
             __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
             __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
+#endif
             scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
             scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
             const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
@@ -9887,7 +9925,11 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
             qs += 8; qh += 4;
         }
 
+#if QK_K == 64
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+#else
         const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+#endif
         accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
         accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
 
@@ -9907,7 +9949,9 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
         const uint8_t  * qh = x[i].qh;
         const uint16_t * sc = (const uint16_t *)x[i].scales;
 
+#if QK_K != 64
         scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+#endif
 
         int sumi1 = 0, sumi2 = 0;
         for (int ib = 0; ib < QK_K/32; ++ib) {
@@ -9927,15 +9971,24 @@ void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void
                 sum1[l/2] += lsum1;
                 sum2[l/2] += lsum2*delta[l];
             }
+#if QK_K == 64
+            const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1;
+            const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1;
+#else
             const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
             const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
+#endif
             sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
             sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
             qs += 4;
             qh += 2;
         }
 
+#if QK_K == 64
+        sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+#else
         sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+#endif
     }
 
     *s = sumf;
@@ -11986,7 +12039,9 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
 
     for (int ibl = 0; ibl < nbl; ++ibl) {
 
-        //y[ibl].d = GGML_FP32_TO_FP16(0.f);
+#if QK_K == 64
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+#endif
         memset(y[ibl].qs, 0, QK_K/8);
         memset(y[ibl].qh, 0, QK_K/16);
         memset(y[ibl].scales, 0, QK_K/32);
@@ -12161,13 +12216,22 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
         }
 
         uint16_t * sc = (uint16_t *)y[ibl].scales;
+#if QK_K == 64
+        float d = max_scale/31;
+#else
         float d = max_scale/15;
+#endif
         float id = 1/d;
         float sumqx_f = 0, sumq2_f = 0;
         for (int ib = 0; ib < QK_K/block_size; ++ib) {
             int l = nearest_int(0.5f*(id*scales[ib+0]-1));
+#if QK_K == 64
+            l = MAX(0, MIN(15, l));
+            sc[ib/4] |= (l << 4*(ib%4));
+#else
             l = MAX(0, MIN(7, l));
             sc[ib/4] |= (l << 3*(ib%4));
+#endif
             y[ibl].qh[ib] |= masks[shifts[ib]];
             const float * xb = xbl + block_size*ib;
             if (quant_weights) {
@@ -12190,10 +12254,14 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy
         }
         if (sumq2_f > 0) d = sumqx_f/sumq2_f;
         s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
+#if QK_K == 64
+        y[ibl].d = s.f16;
+#else
         sc[0] |= ((s.u16 & 0x000f) << 12);
         sc[1] |= ((s.u16 & 0x00f0) <<  8);
         sc[2] |= ((s.u16 & 0x0f00) <<  4);
         sc[3] |= ((s.u16 & 0xf000) <<  0);
+#endif
     }
 }