]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (cuBLAS, Q4_3, bug fix, etc)
authorGeorgi Gerganov <redacted>
Thu, 20 Apr 2023 19:00:49 +0000 (22:00 +0300)
committerGeorgi Gerganov <redacted>
Thu, 20 Apr 2023 19:00:49 +0000 (22:00 +0300)
include/ggml/ggml.h
src/ggml.c

index 570147fc246587603042beed51a0cdb61bae17f1..a8a7b6b4ff504c6dab5b7f0d03cbb2ef29934815 100644 (file)
@@ -205,7 +205,8 @@ enum ggml_type {
     GGML_TYPE_Q4_0 = 2,
     GGML_TYPE_Q4_1 = 3,
     GGML_TYPE_Q4_2 = 4,
-    GGML_TYPE_Q8_0 = 5,
+    GGML_TYPE_Q4_3 = 5,
+    GGML_TYPE_Q8_0 = 6,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -360,6 +361,8 @@ const char * ggml_type_name(enum ggml_type type);
 
 size_t ggml_element_size(const struct ggml_tensor * tensor);
 
+bool ggml_is_quantized(enum ggml_type type);
+
 struct ggml_context * ggml_init(struct ggml_init_params params);
 void ggml_free(struct ggml_context * ctx);
 
@@ -808,6 +811,9 @@ enum ggml_opt_result ggml_opt(
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);
+
+size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
 //
 // system info
index 3b38eaad3673612fe62cbcd11c4b9588d944062c..8109b36b2443264ad3889ec31ac0d5d13deadc42 100644 (file)
@@ -19,6 +19,7 @@
 #include <inttypes.h>
 #include <stdio.h>
 #include <float.h>
+#include <limits.h>
 
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
@@ -149,23 +150,25 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #elif defined(GGML_USE_CUBLAS)
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
-#define CUDA_CHECK(err)                                                                            \
-    do {                                                                                           \
-        cudaError_t err_ = (err);                                                                  \
-        if (err_ != cudaSuccess) {                                                                 \
-            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,                       \
-                cudaGetErrorString(err_));                                                         \
-            exit(1);                                                                               \
-        }                                                                                          \
+#include "ggml-cuda.h"
+
+#define CUDA_CHECK(err)                                                        \
+    do {                                                                       \
+        cudaError_t err_ = (err);                                              \
+        if (err_ != cudaSuccess) {                                             \
+            printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,   \
+                cudaGetErrorString(err_));                                     \
+            exit(1);                                                           \
+        }                                                                      \
     } while (0)
 
-#define CUBLAS_CHECK(err)                                                                          \
-    do {                                                                                           \
-        cublasStatus_t err_ = (err);                                                               \
-        if (err_ != CUBLAS_STATUS_SUCCESS) {                                                       \
-            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);                        \
-            exit(1);                                                                               \
-        }                                                                                          \
+#define CUBLAS_CHECK(err)                                                      \
+    do {                                                                       \
+        cublasStatus_t err_ = (err);                                           \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                   \
+            printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            exit(1);                                                           \
+        }                                                                      \
     } while (0)
 
 static cublasHandle_t cublasH = NULL;
@@ -176,6 +179,7 @@ static void init_cublas(void) {
         CUBLAS_CHECK(cublasCreate(&cublasH));
 
         CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
+
         CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
 
         // configure logging to stdout
@@ -463,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 // quantization
 //
 
-// AVX routines provided by GH user Const-me
-// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+    // Load 8 bytes from memory
+    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+    // Expand bytes into uint16_t values
+    __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+    // Unpack values into individual bytes
+    const __m128i lowMask = _mm_set1_epi8( 0xF );
+    __m128i high = _mm_andnot_si128( lowMask, bytes );
+    __m128i low = _mm_and_si128( lowMask, bytes );
+    high = _mm_slli_epi16( high, 4 );
+    bytes = _mm_or_si128( low, high );
+    return bytes;
+}
+
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
     // Load 16 bytes from memory
     __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -499,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
     __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
     return _mm_packus_epi16( r0, r1 );
 }
-#elif __AVX__
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8( 0xF );
-    __m128i high = _mm_andnot_si128( lowMask, bytes );
-    __m128i low = _mm_and_si128( lowMask, bytes );
-    high = _mm_slli_epi16( high, 4 );
-    bytes = _mm_or_si128( low, high );
-    return bytes;
-}
-
+#else
 static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -533,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
     return _mm_packus_epi16( bytes1, bytes2);
 }
 #endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
 
 #if __ARM_NEON
 
@@ -631,7 +637,7 @@ typedef struct {
     float   m;          // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK4_2 16
 typedef struct {
@@ -640,6 +646,14 @@ typedef struct {
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
 
+#define QK4_3 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qs[QK4_3 / 2]; // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
@@ -1135,12 +1149,136 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
     }
 }
 
+static inline int nearest_int(float fval) {
+    assert(fval <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
+        const float * restrict candidates, int8_t * restrict L) {
+    assert (nmin >= INT8_MIN);
+    assert (nmax <= INT8_MAX);
+    float amax = 0;
+    for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
+    if (!amax) { // all zero
+        for (int i=0; i<n; ++i) L[i] = 0;
+        return 1.f;
+    }
+    float best = 0, bestScale = 0;
+    for (int si=0; si<nCandidates; ++si) {
+        float iscale = candidates[si]/amax;
+        float sumlxP = 0; int suml2P = 0;
+        float sumlxM = 0; int suml2M = 0;
+        for (int i=0; i<n; ++i) {
+            int l = nearest_int(iscale*X[i]);
+            int lp = MAX(nmin, MIN(nmax, +l));
+            int lm = MAX(nmin, MIN(nmax, -l));
+            sumlxP += X[i]*lp; suml2P += lp*lp;
+            sumlxM += X[i]*lm; suml2M += lm*lm;
+        }
+        float sumlxP2 = sumlxP*sumlxP;
+        float sumlxM2 = sumlxM*sumlxM;
+        if (sumlxP2*suml2M > sumlxM2*suml2P) {
+            if (sumlxP2 > best*suml2P) {
+                best = sumlxP2/suml2P; bestScale = iscale;
+            }
+        } else {
+            if (sumlxM2 > best*suml2M) {
+                best = sumlxM2/suml2M; bestScale = -iscale;
+            }
+        }
+    }
+    float sumlx = 0; int suml2 = 0;
+    for (int i=0; i<n; ++i) {
+        int l = nearest_int(bestScale*X[i]);
+        l = MAX(nmin, MIN(nmax, l));
+        sumlx += X[i]*l; suml2 += l*l;
+        L[i] = l;
+    }
+    float scale = sumlx/suml2;
+    return scale;
+}
+
+static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
+#define CANDIDATE_COUNT 8
+    static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
+    assert(k % QK4_2 == 0);
+
+    int8_t L[QK4_2];
+
+    const int nb = k / QK4_2;
+
+    for (int i = 0; i < nb; i++) {
+        float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
+        y[i].d = GGML_FP32_TO_FP16(scale);
+
+        for (int l = 0; l < QK4_2; l += 2) {
+            const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
+            const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+
+        x += QK4_2;
+    }
+}
+
 static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK4_2 == 0);
 
     block_q4_2 * restrict y = vy;
 
-    quantize_row_q4_2_reference(x, y, k);
+    //quantize_row_q4_2_reference(x, y, k);
+    // This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
+    quantize_row_q4_2_rmse(x, y, k);
+}
+
+static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK4_3; l++) {
+            const float v = x[i*QK4_3 + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
+
+            const uint8_t vi0 = (int) (v0 + 0.5f);
+            const uint8_t vi1 = (int) (v1 + 0.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_3 == 0);
+
+    block_q4_3 * restrict y = vy;
+
+    quantize_row_q4_3_reference(x, y, k);
 }
 
 // reference implementation for deterministic creation of model files
@@ -1309,7 +1447,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
 
         for (int l = 0; l < QK4_0; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
             // Subtract 8 from the integers
             vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1427,7 +1565,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
 
         for (int l = 0; l < QK4_1; l += 32) {
             // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytesFromNibbles(pp+l/2);
+            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
 
             // Convert to 16-bit int
             const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -1547,9 +1685,40 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }
 
+static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    const block_q4_3 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK4_3 + l + 0] = v0;
+            y[i*QK4_3 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_3 + l + 0]));
+            assert(!isnan(y[i*QK4_3 + l + 1]));
+        }
+    }
+}
+
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
@@ -1569,10 +1738,17 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q         = dequantize_row_q4_2,
         .quantize_row_q           = quantize_row_q4_2,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
         .quantize_row_q_dot       = quantize_row_q8_0,
         .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
     },
+    [GGML_TYPE_Q4_3] = {
+        .dequantize_row_q         = dequantize_row_q4_3,
+        .quantize_row_q           = quantize_row_q4_3,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_3_q8_0,
+    },
     [GGML_TYPE_Q8_0] = {
         .dequantize_row_q         = NULL,   // TODO
         .quantize_row_q           = quantize_row_q8_0,
@@ -2270,7 +2446,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         /* Compute combined scale for the block */
         const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
 
-        __m256i bx = bytesFromNibbles(x[i].qs);
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
 
         // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
         const __m256i off = _mm256_set1_epi8( 8 );
@@ -2316,7 +2492,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         __m128i i32[2];
         for (int j = 0; j < 2; ++j) {
             // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+            __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
             __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
 
             // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2481,7 +2657,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
         const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
 
         // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        const __m256i bx = bytesFromNibbles( x[i].qs );
+        const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
         // Get absolute values of x vectors
@@ -2567,6 +2743,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
         const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
         const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
         const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
+
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
@@ -2635,6 +2812,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+        const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+        const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
+
+        __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+        __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+        __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8(8);
+        bx = _mm256_sub_epi8(bx, off);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+        /* Convert to vectore of 8 int32_t to 8 floats */
+        __m256 q = _mm256_cvtepi32_ps(xy_q);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps(acc, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+    sumf = _mm_cvtss_f32(res);
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
@@ -2676,6 +2898,154 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     *s = sumf;
 }
 
+static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == 2*QK4_2);
+
+    const block_q4_3 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
+        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
+
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
+        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
+        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
+
+        const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
+        const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
+        const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
+        const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
+
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
+        const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
+
+        const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
+        const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
+        const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
+
+        int sy_0 = 0;
+        int sy_1 = 0;
+
+        int sxy_0 = 0;
+        int sxy_1 = 0;
+
+        for (int j = 0; j < QK8_0/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
+
+            const int x0_0 = v0 & 0xf;
+            const int x1_0 = v0 >> 4;
+
+            const int x0_1 = v1 & 0xf;
+            const int x1_1 = v1 >> 4;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            const int y0_1 = y0[2*(j + QK8_0/4) + 0];
+            const int y1_1 = y0[2*(j + QK8_0/4) + 1];
+
+            sy_0 += y0_0 + y1_0;
+            sy_1 += y0_1 + y1_1;
+
+            sxy_0 += x0_0*y0_0 + x1_0*y1_0;
+            sxy_1 += x0_1*y0_1 + x1_1*y1_1;
+        }
+
+        sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
+        sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+    }
+#endif
+
+    *s = sumf;
+}
+
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -2923,12 +3293,13 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = QK4_0,
     [GGML_TYPE_Q4_1] = QK4_1,
     [GGML_TYPE_Q4_2] = QK4_2,
+    [GGML_TYPE_Q4_3] = QK4_3,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
@@ -2936,12 +3307,13 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
     [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
+    [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2950,12 +3322,13 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = "q4_0",
     [GGML_TYPE_Q4_1] = "q4_1",
     [GGML_TYPE_Q4_2] = "q4_2",
+    [GGML_TYPE_Q4_3] = "q4_3",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = false,
@@ -2963,12 +3336,13 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = true,
     [GGML_TYPE_Q4_1] = true,
     [GGML_TYPE_Q4_2] = true,
+    [GGML_TYPE_Q4_3] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_I8]   = false,
     [GGML_TYPE_I16]  = false,
     [GGML_TYPE_I32]  = false,
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3230,7 +3604,7 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
-static inline bool ggml_is_quantized(enum ggml_type type) {
+bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
 
@@ -5766,7 +6140,6 @@ static void ggml_compute_forward_dup_f32(
                 i10 += ne00 * ir0;
                 while (i10 >= ne0) {
                     i10 -= ne0;
-                    i11++;
                     if (++i11 == ne1) {
                         i11 = 0;
                         if (++i12 == ne2) {
@@ -6180,6 +6553,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7228,7 +7602,6 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7240,6 +7613,7 @@ static void ggml_compute_forward_mul_mat_f32(
             }
         }
 #if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
@@ -7452,7 +7826,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7470,6 +7843,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
@@ -7639,13 +8013,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
-        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-
 #if defined(GGML_USE_CUBLAS)
         float *d_X = NULL;
         float *d_Y = NULL;
         float *d_D = NULL;
+        float *d_Q = NULL;
         const float alpha = 1.0f;
         const float beta = 0.0f;
         const int x_ne = ne01 * ne10;
@@ -7655,10 +8027,41 @@ static void ggml_compute_forward_mul_mat_q_f32(
         CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
         CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
         CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+        CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+
+        void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream)  = NULL;
+        if (type == GGML_TYPE_Q4_0) {
+            dequantize_row_q_cuda = dequantize_row_q4_0_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_1) {
+            dequantize_row_q_cuda = dequantize_row_q4_1_cuda;
+        }
+        else if (type == GGML_TYPE_Q4_2) {
+            dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+#else
+        float * const wdata = params->wdata;
+        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 #endif
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+#if defined(GGML_USE_CUBLAS)
+                // copy and dequantize on device
+                CUDA_CHECK(
+                    cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
+                        GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+
+                dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+#else
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -7666,15 +8069,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
                         id += ne00;
                     }
                 }
-
                 const float * x = wdata;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+#endif
 
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
 #if defined(GGML_USE_CUBLAS)
                 // copy data to device
-                CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
                 CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
 
                 // compute
@@ -7687,7 +8087,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-                CUDA_CHECK(cudaStreamSynchronize(cudaStream));
 #else
                 // zT = y * xT
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7700,9 +8099,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
         }
 
 #if defined(GGML_USE_CUBLAS)
+        CUDA_CHECK(cudaStreamSynchronize(cudaStream));
         CUDA_CHECK(cudaFree(d_X));
         CUDA_CHECK(cudaFree(d_Y));
         CUDA_CHECK(cudaFree(d_D));
+        CUDA_CHECK(cudaFree(d_Q));
 #endif
         //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
 
@@ -7792,6 +8193,7 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
@@ -7809,34 +8211,6 @@ static void ggml_compute_forward_mul_mat(
                 GGML_ASSERT(false);
             } break;
     }
-
-#if 0
-    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
-        static int first = 8;
-        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-        if (first) {
-            --first;
-        } else {
-            for (int k = 0; k < dst->ne[1]; ++k) {
-                for (int j = 0; j < dst->ne[0]/16; ++j) {
-                    for (int i = 0; i < 16; ++i) {
-                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-            }
-            printf("\n");
-            exit(0);
-        }
-    } else {
-        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    }
-#endif
 }
 
 // ggml_compute_forward_scale
@@ -8048,6 +8422,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
@@ -11770,7 +12145,8 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
     for (int j = 0; j < n; j += k) {
         block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
 
-        quantize_row_q4_2_reference(src + j, y, k);
+        //quantize_row_q4_2_reference(src + j, y, k);
+        quantize_row_q4_2_rmse(src + j, y, k);
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK4_2; l += 2) {
@@ -11786,6 +12162,62 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK4_2*sizeof(block_q4_2));
 }
 
+size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
+
+        quantize_row_q4_3_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_3; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_3*sizeof(block_q4_3));
+}
+
+size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist) {
+    size_t result = 0;
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            {
+                GGML_ASSERT(start % QK4_0 == 0);
+                block_q4_0 * block = (block_q4_0*)dst + start / QK4_0;
+                result = ggml_quantize_q4_0(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                GGML_ASSERT(start % QK4_1 == 0);
+                block_q4_1 * block = (block_q4_1*)dst + start / QK4_1;
+                result = ggml_quantize_q4_1(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_2:
+            {
+                GGML_ASSERT(start % QK4_2 == 0);
+                block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
+                result = ggml_quantize_q4_2(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q4_3:
+            {
+                GGML_ASSERT(start % QK4_3 == 0);
+                block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
+                result = ggml_quantize_q4_3(src + start, block, n, n, hist);
+            } break;
+        default:
+            assert(false);
+    }
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 int ggml_cpu_has_avx(void) {