]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp
authorGeorgi Gerganov <redacted>
Sat, 15 Apr 2023 16:50:54 +0000 (19:50 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Apr 2023 16:50:54 +0000 (19:50 +0300)
include/ggml/ggml.h
src/ggml.c

index 617298a95536dd1a1f4c8ed8568ef575ab01cef8..241e96a1975b1c935fa96fabd2585b31f905e2b0 100644 (file)
@@ -204,6 +204,7 @@ enum ggml_type {
     GGML_TYPE_F16  = 1,
     GGML_TYPE_Q4_0 = 2,
     GGML_TYPE_Q4_1 = 3,
+    GGML_TYPE_Q8_0 = 4,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -836,6 +837,7 @@ typedef struct {
     dequantize_row_q_t dequantize_row_q;
     quantize_row_q_t   quantize_row_q;
     quantize_row_q_t   quantize_row_q_reference;
+    quantize_row_q_t   quantize_row_q_dot;
     vec_dot_q_t        vec_dot_q;
 } quantize_fns_t;
 
index cf6a81f43cfed5ad5278eeae330e7f3eb60f947f..ccad76e8c018e31352d4c599277fccc41adaf1b3 100644 (file)
@@ -427,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-#define QK 32
-
 // AVX routines provided by GH user Const-me
 // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
 #if __AVX2__ || __AVX512F__
@@ -571,37 +569,42 @@ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
 #endif
 #endif
 
-// method 5
-// blocks of QK elements
-// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
+
+#define QK4_0 32
 typedef struct {
-    float   d; // delta
-    uint8_t qs[QK / 2]; // nibbles / quants
+    float   d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
-// method 4
-// blocks of QK elements
-// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
+#define QK4_1 32
 typedef struct {
-    float   d;
-    float   m;
-    uint8_t qs[QK / 2]; // nibbles / quants
+    float   d;          // delta
+    float   m;          // min
+    uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    float   d;          // delta
+    int8_t  qs[QK8_0];  // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+
 
 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_0/2];
 
     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_0; l++) {
+            const float v = x[i*QK4_0 + l];
             amax = MAX(amax, fabsf(v));
         }
 
@@ -610,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 
         y[i].d = d;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = x[i*QK + l + 0]*id;
-            const float v1 = x[i*QK + l + 1]*id;
+        for (int l = 0; l < QK4_0; l += 2) {
+            const float v0 = x[i*QK4_0 + l + 0]*id;
+            const float v1 = x[i*QK4_0 + l + 1]*id;
 
             const uint8_t vi0 = (int8_t)roundf(v0) + 8;
             const uint8_t vi1 = (int8_t)roundf(v1) + 8;
@@ -628,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
 }
 
 static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     block_q4_0 * restrict y = vy;
 
@@ -879,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
 }
 
 static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
-    uint8_t pp[QK/2];
+    uint8_t pp[QK4_1/2];
 
     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;
 
-        for (int l = 0; l < QK; l++) {
-            const float v = x[i*QK + l];
+        for (int l = 0; l < QK4_1; l++) {
+            const float v = x[i*QK4_1 + l];
             if (v < min) min = v;
             if (v > max) max = v;
         }
@@ -902,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
         y[i].d = d;
         y[i].m = min;
 
-        for (int l = 0; l < QK; l += 2) {
-            const float v0 = (x[i*QK + l + 0] - min)*id;
-            const float v1 = (x[i*QK + l + 1] - min)*id;
+        for (int l = 0; l < QK4_1; l += 2) {
+            const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
 
             const uint8_t vi0 = roundf(v0);
             const uint8_t vi1 = roundf(v1);
@@ -920,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
 }
 
 static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK == 0);
+    assert(k % QK4_1 == 0);
 
-    const int nb = k / QK;
+    const int nb = k / QK4_1;
 
     block_q4_1 * restrict y = vy;
 
@@ -1006,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
         float32x4_t minv[8];
         float32x4_t maxv[8];
 
-        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
+        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
 
         for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
         for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
@@ -1042,9 +1045,160 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
 #endif
 }
 
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int l = 0; l < QK8_0; l++) {
+            const float v = x[i*QK8_0 + l];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < QK8_0; ++l) {
+            const float   v  = x[i*QK8_0 + l]*id;
+            y[i].qs[l] = roundf(v);
+        }
+    }
+}
+
+static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int l = 0; l < 8; l++) srcv[l]  = vld1q_f32(x + i*32 + 4*l);
+        for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+        for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+        for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+        for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = d;
+
+        for (int l = 0; l < 8; l++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[l], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );     // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );     // 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );     // 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#else
+    // scalar
+    quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
 static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     const block_q4_0 * restrict x = vx;
 
@@ -1055,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_0; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
             __m256i vx8 = bytesFromNibbles(pp+l/2);
 
@@ -1077,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             // Scale and store
             for (int j = 0; j < 4; j++) {
                 const __m256 result = _mm256_mul_ps(vf[j], d_v);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+                _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result);
             }
         }
     }
@@ -1087,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_0; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1126,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmulq_f32(vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_0 + l +  0, r0);
+            vst1q_f32(y + i*QK4_0 + l +  4, r1);
+            vst1q_f32(y + i*QK4_0 + l +  8, r2);
+            vst1q_f32(y + i*QK4_0 + l + 12, r3);
         }
     }
 #else
@@ -1139,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_0; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1150,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
             //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_0 + l + 0] = v0;
+            y[i*QK4_0 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_0 + l + 0]));
+            assert(!isnan(y[i*QK4_0 + l + 1]));
         }
     }
 #endif
 }
 
 static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     const block_q4_1 * restrict x = vx;
 
@@ -1173,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 32) {
+        for (int l = 0; l < QK4_1; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
             __m256i vx8 = bytesFromNibbles(pp+l/2);
 
@@ -1192,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             // Scale, add m and store
             for (int j = 0; j < 4; j++) {
                 const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
-                _mm256_storeu_ps(y + i * QK + l + j*8, result);
+                _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
             }
         }
     }
@@ -1203,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 16) {
+        for (int l = 0; l < QK4_1; l += 16) {
             // Load 16x4-bit integers into 8x8-bit integers
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
@@ -1234,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
 
             // Store
-            vst1q_f32(y + i*QK + l +  0, r0);
-            vst1q_f32(y + i*QK + l +  4, r1);
-            vst1q_f32(y + i*QK + l +  8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
+            vst1q_f32(y + i*QK4_1 + l +  0, r0);
+            vst1q_f32(y + i*QK4_1 + l +  4, r1);
+            vst1q_f32(y + i*QK4_1 + l +  8, r2);
+            vst1q_f32(y + i*QK4_1 + l + 12, r3);
         }
     }
 #else
@@ -1247,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         const uint8_t * restrict pp = x[i].qs;
 
-        for (int l = 0; l < QK; l += 2) {
+        for (int l = 0; l < QK4_1; l += 2) {
             const uint8_t vi = pp[l/2];
 
             const int8_t vi0 = vi & 0xf;
@@ -1256,11 +1410,11 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const float v0 = vi0*d + m;
             const float v1 = vi1*d + m;
 
-            y[i*QK + l + 0] = v0;
-            y[i*QK + l + 1] = v1;
+            y[i*QK4_1 + l + 0] = v0;
+            y[i*QK4_1 + l + 1] = v1;
 
-            assert(!isnan(y[i*QK + l + 0]));
-            assert(!isnan(y[i*QK + l + 1]));
+            assert(!isnan(y[i*QK4_1 + l + 0]));
+            assert(!isnan(y[i*QK4_1 + l + 1]));
         }
     }
 #endif
@@ -1822,7 +1976,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-#if __AVX512F__ && QK == 32
+#if __AVX512F__ && QK4_0 == 32
 static inline __m512 dot_q4_0_oneblock_avx512(
     __m512 acc,
     const block_q4_0 * restrict x,
@@ -1890,9 +2044,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 }
 
 static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+    const int nb = n / QK4_0;
 
-    assert(n % QK == 0);
+    assert(n % QK4_0 == 0);
     assert(nb % 2 == 0);
 
     const block_q4_0 * restrict x = vx;
@@ -2215,7 +2369,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         const uint8_t * restrict p1 = y[i].qs;
 
         int sumi = 0;
-        for (int j = 0; j < QK/2; j++) {
+        for (int j = 0; j < QK4_0/2; j++) {
             const uint8_t v0 = p0[j];
             const uint8_t v1 = p1[j];
 
@@ -2235,7 +2389,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 }
 
 static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK;
+    const int nb = n / QK4_1;
 
     const block_q4_1 * restrict x = vx;
     const block_q4_1 * restrict y = vy;
@@ -2312,7 +2466,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
     res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
 
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
+    sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
 #elif defined(__ARM_NEON)
     float sum00 = 0.0f;
     float sum01 = 0.0f;
@@ -2344,12 +2498,12 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
         const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
 
         sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
+        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
 
         sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
+        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
 
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
@@ -2386,7 +2540,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
 #endif
     }
 
-    sumf = QK*sum00 + sum01 + sum10 + sum11;
+    sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
@@ -2399,7 +2553,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
         const uint8_t * restrict p0 = x[i].qs;
         const uint8_t * restrict p1 = y[i].qs;
 
-        for (int j = 0; j < QK/2; j++) {
+        for (int j = 0; j < QK4_1/2; j++) {
             const uint8_t v0 = p0[j];
             const uint8_t v1 = p1[j];
 
@@ -2417,6 +2571,209 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+#if defined(__ARM_NEON)
+    float sum0 = 0.0f;
+    float sum1 = 0.0f;
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_0 * restrict x0 = &x[i + 0];
+        const block_q4_0 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b   = vdupq_n_u8(0xf);
+        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
+        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+
+        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
+        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+
+        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+#endif
+    }
+
+    sumf = sum0 + sum1;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m256i bx = bytesFromNibbles(x[i].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        bx = _mm256_sub_epi8( bx, off );
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert to vectore of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+        __m128i i32[2];
+        for (int j = 0; j < 2; ++j) {
+            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+            const __m128i off = _mm_set1_epi8( 8 );
+            bx = _mm_sub_epi8( bx, off );
+
+            // Get absolute values of x vectors
+            const __m128i ax = _mm_sign_epi8(bx, bx);
+
+            // Sign the values of the y vectors
+            const __m128i sy = _mm_sign_epi8(by, bx);
+
+            // Perform multiplication and create 16-bit values
+            const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+            const __m128i ones = _mm_set1_epi16(1);
+            i32[j] = _mm_madd_epi16(ones, dot);
+        }
+
+        // Convert int32_t to float
+        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        // Apply the scale, and accumulate
+        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        int sumi = 0;
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const int i0 = (int8_t) (v0 & 0xf) - 8;
+            const int i1 = (int8_t) (v0 >> 4)  - 8;
+
+            const int i2 = p1[2*j + 0];
+            const int i3 = p1[2*j + 1];
+
+            sumi += i0*i2 + i1*i3;
+        }
+        sumf += d0*d1*sumi;
+    }
+#endif
+
+    *s = sumf;
+}
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -2661,24 +3018,26 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
 static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = 1,
     [GGML_TYPE_F16]  = 1,
-    [GGML_TYPE_Q4_0] = QK,
-    [GGML_TYPE_Q4_1] = QK,
+    [GGML_TYPE_Q4_0] = QK4_0,
+    [GGML_TYPE_Q4_1] = QK4_1,
+    [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
     [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+    [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2686,11 +3045,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F16]  = "f16",
     [GGML_TYPE_Q4_0] = "q4_0",
     [GGML_TYPE_Q4_1] = "q4_1",
+    [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3363,14 +3723,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -3406,7 +3758,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
                     ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
                 }
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3423,14 +3775,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
     char * const data = tensor->data;
 
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 assert(tensor->nb[0] == sizeof(int8_t));
@@ -3466,7 +3810,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
                     ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
                 }
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3477,14 +3821,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
 
 int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3510,7 +3846,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3521,14 +3857,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3554,7 +3882,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 ((float *)(tensor->data))[i] = value;
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3563,14 +3891,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
 
 float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3596,7 +3916,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -3607,14 +3927,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
 
 void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     switch (tensor->type) {
-        case GGML_TYPE_Q4_0:
-            {
-                GGML_ASSERT(false);
-            } break;
-        case GGML_TYPE_Q4_1:
-            {
-                GGML_ASSERT(false);
-            } break;
         case GGML_TYPE_I8:
             {
                 GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3640,7 +3952,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
                 GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 ((float *)(tensor->data))[i] = value;
             } break;
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5435,12 +5747,7 @@ static void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5516,13 +5823,7 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5568,13 +5869,7 @@ static void ggml_compute_forward_sub(
             {
                 ggml_compute_forward_sub_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5620,13 +5915,7 @@ static void ggml_compute_forward_mul(
             {
                 ggml_compute_forward_mul_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5672,13 +5961,7 @@ static void ggml_compute_forward_div(
             {
                 ggml_compute_forward_div_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5720,13 +6003,7 @@ static void ggml_compute_forward_sqr(
             {
                 ggml_compute_forward_sqr_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5768,13 +6045,7 @@ static void ggml_compute_forward_sqrt(
             {
                 ggml_compute_forward_sqrt_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5826,13 +6097,7 @@ static void ggml_compute_forward_sum(
             {
                 ggml_compute_forward_sum_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5903,13 +6168,7 @@ static void ggml_compute_forward_mean(
             {
                 ggml_compute_forward_mean_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -5967,13 +6226,7 @@ static void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6015,13 +6268,7 @@ static void ggml_compute_forward_abs(
             {
                 ggml_compute_forward_abs_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6063,13 +6310,7 @@ static void ggml_compute_forward_sgn(
             {
                 ggml_compute_forward_sgn_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6111,13 +6352,7 @@ static void ggml_compute_forward_neg(
             {
                 ggml_compute_forward_neg_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6159,13 +6394,7 @@ static void ggml_compute_forward_step(
             {
                 ggml_compute_forward_step_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6207,13 +6436,7 @@ static void ggml_compute_forward_relu(
             {
                 ggml_compute_forward_relu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6272,13 +6495,7 @@ static void ggml_compute_forward_gelu(
             {
                 ggml_compute_forward_gelu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6339,13 +6556,7 @@ static void ggml_compute_forward_silu(
             {
                 ggml_compute_forward_silu_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6425,13 +6636,7 @@ static void ggml_compute_forward_norm(
             {
                 ggml_compute_forward_norm_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6505,13 +6710,7 @@ static void ggml_compute_forward_rms_norm(
             {
                 ggml_compute_forward_rms_norm_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -6908,14 +7107,17 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q         = dequantize_row_q4_0,
         .quantize_row_q           = quantize_row_q4_0,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_0,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
     },
     [GGML_TYPE_Q4_1] = {
         .dequantize_row_q         = dequantize_row_q4_1,
         .quantize_row_q           = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .quantize_row_q_dot       = quantize_row_q4_1,
         .vec_dot_q                = ggml_vec_dot_q4_1,
     },
+    // TODO: GGML_TYPE_Q8_0
 };
 
 // For internal test use
@@ -6971,8 +7173,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3  == ne13);
 
     const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
-    vec_dot_q_t      const vec_dot_q      = quantize_fns[type].vec_dot_q;
+    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+    vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7041,12 +7243,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
     if (params->type == GGML_TASK_INIT) {
         char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+        const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
 
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
                     wdata += row_size;
                 }
             }
@@ -7072,7 +7274,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+    const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
 
     for (int ir = ir0; ir < ir1; ++ir) {
         // src0 indices
@@ -7120,6 +7322,7 @@ static void ggml_compute_forward_mul_mat(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
             } break;
@@ -7131,10 +7334,7 @@ static void ggml_compute_forward_mul_mat(
             {
                 ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7216,13 +7416,7 @@ static void ggml_compute_forward_scale(
             {
                 ggml_compute_forward_scale_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7383,6 +7577,7 @@ static void ggml_compute_forward_get_rows(
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -7394,10 +7589,7 @@ static void ggml_compute_forward_get_rows(
             {
                 ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7470,13 +7662,7 @@ static void ggml_compute_forward_diag_mask_inf(
             {
                 ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7564,13 +7750,7 @@ static void ggml_compute_forward_soft_max(
             {
                 ggml_compute_forward_soft_max_f32(params, src0, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -7747,12 +7927,7 @@ static void ggml_compute_forward_rope(
             {
                 ggml_compute_forward_rope_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8015,12 +8190,7 @@ static void ggml_compute_forward_conv_1d_1s(
             {
                 ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8283,12 +8453,7 @@ static void ggml_compute_forward_conv_1d_2s(
             {
                 ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8768,12 +8933,7 @@ static void ggml_compute_forward_flash_attn(
             {
                 ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -8979,12 +9139,7 @@ static void ggml_compute_forward_flash_ff(
             {
                 GGML_ASSERT(false); // TODO
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9028,13 +9183,7 @@ static void ggml_compute_forward_map_unary(
             {
                 ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9083,13 +9232,7 @@ static void ggml_compute_forward_map_binary(
             {
                 ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
             } break;
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
-        case GGML_TYPE_COUNT:
+        default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -9914,7 +10057,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             } else
 #endif
                             {
-                                cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
                             }
                         } else {
                             GGML_ASSERT(false);
@@ -11089,16 +11232,16 @@ enum ggml_opt_result ggml_opt(
 ////////////////////////////////////////////////////////////////////////////////
 
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_0 == 0);
+    const int nb = k / QK4_0;
 
     for (int j = 0; j < n; j += k) {
-        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
+        block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0;
 
         quantize_row_q4_0_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
+            for (int l = 0; l < QK4_0; l += 2) {
                 const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
@@ -11108,20 +11251,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/QK*sizeof(block_q4_0));
+    return (n/QK4_0*sizeof(block_q4_0));
 }
 
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
-    assert(k % QK == 0);
-    const int nb = k / QK;
+    assert(k % QK4_1 == 0);
+    const int nb = k / QK4_1;
 
     for (int j = 0; j < n; j += k) {
-        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
+        block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1;
 
         quantize_row_q4_1_reference(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
-            for (int l = 0; l < QK; l += 2) {
+            for (int l = 0; l < QK4_1; l += 2) {
                 const uint8_t vi0 = y[i].qs[l/2] & 0xF;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
@@ -11131,7 +11274,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
         }
     }
 
-    return (n/QK*sizeof(block_q4_1));
+    return (n/QK4_1*sizeof(block_q4_1));
 }
 
 ////////////////////////////////////////////////////////////////////////////////