]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add Q5_0 and Q5_1 quantization (#1187)
authorGeorgi Gerganov <redacted>
Wed, 26 Apr 2023 20:14:13 +0000 (23:14 +0300)
committerGitHub <redacted>
Wed, 26 Apr 2023 20:14:13 +0000 (23:14 +0300)
* ggml : add Q5_0 quantization (cuBLAS only)

* ggml : fix Q5_0 qh -> uint32_t

* ggml : fix q5_0 histogram stats

* ggml : q5_0 scalar dot product

* ggml : q5_0 ARM NEON dot

* ggml : q5_0 more efficient ARM NEON using uint64_t masks

* ggml : rename Q5_0 -> Q5_1

* ggml : adding Q5_0 mode

* quantize : add Q5_0 and Q5_1 to map

* ggml : AVX2 optimizations for Q5_0, Q5_1 (#1195)

---------

Co-authored-by: Stephan Walter <redacted>
.gitignore
examples/quantize/quantize.cpp
ggml-cuda.cu
ggml-cuda.h
ggml.c
ggml.h
llama.cpp
llama.h

index e52d479eeafa8e3c29892341dff70fb12df077f6..c7573bb3b93c42f89ecbe68b5a41f4842c4af56f 100644 (file)
@@ -15,6 +15,7 @@ build-em/
 build-debug/
 build-release/
 build-static/
+build-cublas/
 build-no-accel/
 build-sanitize-addr/
 build-sanitize-thread/
index ec7f91aae6bf029187aab53160f3cd35aef34542..60966595e9561c211b018ae6e20cbfc32f9bad65 100644 (file)
@@ -10,6 +10,8 @@ static const std::map<std::string, enum llama_ftype> LLAMA_FTYPE_MAP = {
   {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1},
   {"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2},
   {"q4_3", LLAMA_FTYPE_MOSTLY_Q4_3},
+  {"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0},
+  {"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
   {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
 };
 
index f104ed5ac42ddb981f842c437271cb716d7d069c..b1bd29b100d811dc09db743c4c6537e2faf82009 100644 (file)
@@ -37,6 +37,23 @@ typedef struct {
 } block_q4_3;
 static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
+#define QK5_0 32
+typedef struct {
+    __half d;               // delta
+    uint8_t qh[4];          // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2];  // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    __half d;               // delta
+    __half m;               // min
+    uint32_t qh;            // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2];  // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
 #define QK8_0 32
 typedef struct {
     float   d;              // delta
@@ -138,6 +155,64 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
     }
 }
 
+static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+
+    const uint8_t * pp = x[i].qs;
+
+    uint32_t qh;
+    memcpy(&qh, x[i].qh, sizeof(qh));
+
+    for (int l = 0; l < QK5_0; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+        const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+        const int8_t vi0 = ((vi & 0xf) | vh0);
+        const int8_t vi1 = ((vi >>  4) | vh1);
+
+        const float v0 = (vi0 - 16)*d;
+        const float v1 = (vi1 - 16)*d;
+
+        y[i*QK5_0 + l + 0] = v0;
+        y[i*QK5_0 + l + 1] = v1;
+    }
+}
+
+static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const int i = blockIdx.x;
+
+    const float d = x[i].d;
+    const float m = x[i].m;
+
+    const uint8_t * pp = x[i].qs;
+
+    const uint32_t qh = x[i].qh;
+
+    for (int l = 0; l < QK5_1; l += 2) {
+        const uint8_t vi = pp[l/2];
+
+        const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+        const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+        const int8_t vi0 = (vi & 0xf) | vh0;
+        const int8_t vi1 = (vi >>  4) | vh1;
+
+        const float v0 = vi0*d + m;
+        const float v1 = vi1*d + m;
+
+        y[i*QK5_1 + l + 0] = v0;
+        y[i*QK5_1 + l + 1] = v1;
+    }
+}
+
 static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
 
@@ -174,6 +249,16 @@ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t st
     dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
 }
 
+void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK5_0;
+    dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
+}
+
+void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
+    const int nb = k / QK5_1;
+    dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
+}
+
 void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
     const int nb = k / QK8_0;
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
index 4048ea49193214873933da663d42c5c6b70043cb..ed9b44184bf566ecb0036f3212f9681e7a3be099 100644 (file)
@@ -35,6 +35,8 @@ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t st
 void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
+void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
 
 #ifdef  __cplusplus
diff --git a/ggml.c b/ggml.c
index 064510edaa798d4107775996afd2e2206eac4220..03b4bd439f29945cb5aeb12f47a564c4c7464e24 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -328,6 +328,19 @@ static ggml_fp16_t table_exp_f16[1 << 16];
 // precomputed f32 table for f16 (256 KB)
 static float table_f32_f16[1 << 16];
 
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes (shl 4)
+static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
+static const uint64_t table_b2b_i[1 << 8] = { B8(F0, 00) };
+
 // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
 // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
 // This is also true for POWER9.
@@ -673,6 +686,23 @@ typedef struct {
 } block_q4_3;
 static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
 
+#define QK5_0 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
@@ -1288,6 +1318,103 @@ static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int
     quantize_row_q4_3_reference(x, y, k);
 }
 
+static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max = 0.0f;
+
+        for (int l = 0; l < QK5_0; l++) {
+            const float v = x[i*QK5_0 + l];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max = v;
+            }
+        }
+
+        const float d = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        uint32_t qh = 0;
+
+        for (int l = 0; l < QK5_0; l += 2) {
+            const float v0 = x[i*QK5_0 + l + 0]*id;
+            const float v1 = x[i*QK5_0 + l + 1]*id;
+
+            const uint32_t vi0 = MIN(31, (int) (v0 + 16.5f));
+            const uint32_t vi1 = MIN(31, (int) (v1 + 16.5f));
+
+            y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+            qh |= ((vi1 & 0x10) >> 4) << (l + 1);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+    }
+}
+
+static void quantize_row_q5_0(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK5_0 == 0);
+
+    block_q5_0 * restrict y = vy;
+
+    quantize_row_q5_0_reference(x, y, k);
+}
+
+static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK5_1; l++) {
+            const float v = x[i*QK5_1 + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 5) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        uint32_t qh = 0;
+
+        for (int l = 0; l < QK5_1; l += 2) {
+            const float v0 = (x[i*QK5_1 + l + 0] - min)*id;
+            const float v1 = (x[i*QK5_1 + l + 1] - min)*id;
+
+            const uint32_t vi0 = (int) (v0 + 0.5f);
+            const uint32_t vi1 = (int) (v1 + 0.5f);
+
+            y[i].qs[l/2] = (vi0 & 0x0F) | ((vi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((vi0 & 0x10) >> 4) << (l + 0);
+            qh |= ((vi1 & 0x10) >> 4) << (l + 1);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+    }
+}
+
+static void quantize_row_q5_1(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK5_1 == 0);
+
+    block_q5_1 * restrict y = vy;
+
+    quantize_row_q5_1_reference(x, y, k);
+}
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
     assert(k % QK8_0 == 0);
@@ -1571,7 +1698,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
             // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
             const uint8x8_t v1 = vshr_n_u8(v8, 4);
 
             // Convert to signed 8-bit integers
@@ -1621,7 +1748,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
         for (int l = 0; l < QK4_0; l += 2) {
             const uint8_t vi = pp[l/2];
 
-            const int8_t vi0 = vi & 0xf;
+            const int8_t vi0 = vi & 0x0F;
             const int8_t vi1 = vi >> 4;
 
             const float v0 = (vi0 - 8)*d;
@@ -1687,7 +1814,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             const uint8x8_t v8 = vld1_u8(pp + l/2);
 
             // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
             const uint8x8_t v1 = vshr_n_u8(v8, 4);
 
             // Interleave and combine
@@ -1729,7 +1856,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
         for (int l = 0; l < QK4_1; l += 2) {
             const uint8_t vi = pp[l/2];
 
-            const int8_t vi0 = vi & 0xf;
+            const int8_t vi0 = vi & 0x0F;
             const int8_t vi1 = vi >> 4;
 
             const float v0 = vi0*d + m;
@@ -1759,7 +1886,7 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
         for (int l = 0; l < QK4_2; l += 2) {
             const uint8_t vi = pp[l/2];
 
-            const int8_t vi0 = vi & 0xf;
+            const int8_t vi0 = vi & 0x0F;
             const int8_t vi1 = vi >> 4;
 
             const float v0 = (vi0 - 8)*d;
@@ -1789,7 +1916,7 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
         for (int l = 0; l < QK4_3; l += 2) {
             const uint8_t vi = pp[l/2];
 
-            const int8_t vi0 = vi & 0xf;
+            const int8_t vi0 = vi & 0x0F;
             const int8_t vi1 = vi >> 4;
 
             const float v0 = vi0*d + m;
@@ -1804,6 +1931,79 @@ static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, in
     }
 }
 
+static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
+
+    const block_q5_0 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int l = 0; l < QK5_0; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            // extract the 5-th bit from qh
+            const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+            const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+            const int8_t vi0 = (vi & 0x0F) | vh0;
+            const int8_t vi1 = (vi >>   4) | vh1;
+
+            const float v0 = (vi0 - 16)*d;
+            const float v1 = (vi1 - 16)*d;
+
+            y[i*QK5_0 + l + 0] = v0;
+            y[i*QK5_0 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK5_0 + l + 0]));
+            assert(!isnan(y[i*QK5_0 + l + 1]));
+        }
+    }
+}
+
+static void dequantize_row_q5_1(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;
+
+    const block_q5_1 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int l = 0; l < QK5_1; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            // extract the 5-th bit from qh
+            const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+            const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+            const uint8_t vi0 = (vi & 0x0F) | vh0;
+            const uint8_t vi1 = (vi >>   4) | vh1;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK5_1 + l + 0] = v0;
+            y[i*QK5_1 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK5_1 + l + 0]));
+            assert(!isnan(y[i*QK5_1 + l + 1]));
+        }
+    }
+}
+
 static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -1825,6 +2025,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
 static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
@@ -1860,6 +2062,22 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_q                = ggml_vec_dot_q4_3_q8_1,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
+    [GGML_TYPE_Q5_0] = {
+        .dequantize_row_q         = dequantize_row_q5_0,
+        .quantize_row_q           = quantize_row_q5_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q5_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+    },
+    [GGML_TYPE_Q5_1] = {
+        .dequantize_row_q         = dequantize_row_q5_1,
+        .quantize_row_q           = quantize_row_q5_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
+        .quantize_row_q_dot       = quantize_row_q8_1,
+        .vec_dot_q                = ggml_vec_dot_q5_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+    },
     [GGML_TYPE_Q8_0] = {
         .dequantize_row_q         = dequantize_row_q8_0,
         .quantize_row_q           = quantize_row_q8_0,
@@ -2496,7 +2714,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b   = vdupq_n_u8(0xf);
+        const uint8x16_t m4b   = vdupq_n_u8(0x0F);
         const int8x16_t  s8b   = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
@@ -2632,8 +2850,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         for (int j = 0; j < QK8_0/2; j++) {
             const uint8_t v0 = p0[j];
 
-            const int i0 = (int8_t) (v0 & 0xf) - 8;
-            const int i1 = (int8_t) (v0 >> 4)  - 8;
+            const int i0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1 = (int8_t) (v0 >>   4) - 8;
 
             const int i2 = p1[2*j + 0];
             const int i3 = p1[2*j + 1];
@@ -2670,7 +2888,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
 
         summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1);
 
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2767,8 +2985,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         for (int j = 0; j < QK8_1/2; j++) {
             const uint8_t v0 = p0[j];
 
-            const float f0 = d0*(v0 & 0xf) + m0;
-            const float f1 = d0*(v0 >> 4)  + m0;
+            const float f0 = d0*(v0 & 0x0F) + m0;
+            const float f1 = d0*(v0 >>   4) + m0;
 
             const float f2 = d1*p1[2*j + 0];
             const float f3 = d1*p1[2*j + 1];
@@ -2803,7 +3021,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b   = vdupq_n_u8(0xf);
+        const uint8x16_t m4b   = vdupq_n_u8(0x0F);
         const int8x16_t  s8b   = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
@@ -2914,11 +3132,11 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
             const uint8_t v0 = x0[j];
             const uint8_t v1 = x1[j];
 
-            const int i0_0 = (int8_t) (v0 & 0xf) - 8;
-            const int i1_0 = (int8_t) (v0 >> 4)  - 8;
+            const int i0_0 = (int8_t) (v0 & 0x0F) - 8;
+            const int i1_0 = (int8_t) (v0 >>   4) - 8;
 
-            const int i0_1 = (int8_t) (v1 & 0xf) - 8;
-            const int i1_1 = (int8_t) (v1 >> 4)  - 8;
+            const int i0_1 = (int8_t) (v1 & 0x0F) - 8;
+            const int i1_1 = (int8_t) (v1 >>   4) - 8;
 
             const int i2_0 = y0[2*j + 0];
             const int i3_0 = y0[2*j + 1];
@@ -2966,7 +3184,7 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
         const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0xf)));
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0x0F)));
         const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
 
         // interleave
@@ -3045,10 +3263,10 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
             const uint8_t v0 = x0[j];
             const uint8_t v1 = x1[j];
 
-            const int x0_0 = v0 & 0xf;
+            const int x0_0 = v0 & 0x0F;
             const int x1_0 = v0 >> 4;
 
-            const int x0_1 = v1 & 0xf;
+            const int x0_1 = v1 & 0x0F;
             const int x1_1 = v1 >> 4;
 
             const int y0_0 = y0[2*j + 0];
@@ -3067,6 +3285,273 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
 #endif
 }
 
+static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == QK5_0);
+
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
+
+    uint64_t tmp[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_0 * restrict y0 = &y[i];
+
+        const uint8x16_t m4b  = vdupq_n_u8(0x0F);
+        const int8x16_t  s16b = vdupq_n_s8(0x10);
+
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
+
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, m4b));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+
+        // interleave
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
+
+        // add high bit and sub 16
+        const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0lz, qhl), s16b);
+        const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0hz, qhh), s16b);
+
+        // load y
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = _mm256_set_epi64x(
+            table_b2b_i[x[i].qh[3]], table_b2b_i[x[i].qh[2]],
+            table_b2b_i[x[i].qh[1]], table_b2b_i[x[i].qh[0]]);
+        bx = _mm256_or_si256(bx, bxhi);
+
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    *s = hsum_float_8(acc);
+#else
+    // scalar
+    float sumf = 0.0;
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[i].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        int sxy = 0;
+
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = x0[j];
+
+            const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
+
+            const int x0_0 = ((v0 & 0x0F) | x0_0h) - 16;
+            const int x1_0 = ((v0 >>   4) | x1_0h) - 16;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            sxy += x0_0*y0_0 + x1_0*y1_0;
+        }
+
+        sumf += (d*sxy)*y[i].d;
+    }
+    *s = sumf;
+#endif
+}
+
+static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_1;
+
+    assert(n % QK8_1 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_1 == QK5_1);
+
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
+
+    float summs = 0.0f;
+
+    uint64_t tmp[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const block_q5_1 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
+
+        // extract the 5th bit
+        uint32_t qh;
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_u[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_u[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_u[(qh >> 24)       ];
+
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, vdupq_n_u8(0x0F)));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
+
+        // interleave
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
+
+        // add
+        const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
+        const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
+
+        // load y
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
+
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
+#endif
+    }
+
+    *s = vaddvq_f32(sumv) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    float summs = 0.0f;
+
+    // Main loop
+    for (int i = 0; i < nb; i++) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+        summs += GGML_FP16_TO_FP32(x[i].m) * (y[i].s0 + y[i].s1);
+
+        __m256i bx = bytes_from_nibbles_32(x[i].qs);
+        const __m256i bxhi = _mm256_set_epi64x(
+            table_b2b_u[x[i].qh[3]], table_b2b_u[x[i].qh[2]],
+            table_b2b_u[x[i].qh[1]], table_b2b_u[x[i].qh[0]]);
+        bx = _mm256_or_si256(bx, bxhi);
+
+        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+    }
+
+    *s = hsum_float_8(acc) + summs;
+#else
+    float sumf = 0.0;
+
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[i].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        int sxy = 0;
+
+        for (int j = 0; j < QK8_1/2; j++) {
+            const uint8_t v0 = x0[j];
+
+            const int x0_0h = ((qh & (1 << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1 << (2*j + 1))) >> (2*j + 1)) << 4;
+
+            const int x0_0 = (v0 & 0x0F) | x0_0h;
+            const int x1_0 = (v0 >>   4) | x1_0h;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            sxy += x0_0*y0_0 + x1_0*y1_0;
+        }
+
+        sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
+    }
+
+    *s = sumf;
+#endif
+}
+
 static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
@@ -3409,13 +3894,15 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_1] = QK4_1,
     [GGML_TYPE_Q4_2] = QK4_2,
     [GGML_TYPE_Q4_3] = QK4_3,
+    [GGML_TYPE_Q5_0] = QK5_0,
+    [GGML_TYPE_Q5_1] = QK5_1,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_Q8_1] = QK8_1,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 11, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 13, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
@@ -3424,13 +3911,15 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
     [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
     [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
+    [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
+    [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3440,13 +3929,15 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_1] = "q4_1",
     [GGML_TYPE_Q4_2] = "q4_2",
     [GGML_TYPE_Q4_3] = "q4_3",
+    [GGML_TYPE_Q5_0] = "q5_0",
+    [GGML_TYPE_Q5_1] = "q5_1",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_Q8_1] = "q8_1",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 11, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 13, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = false,
@@ -3455,13 +3946,15 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_1] = true,
     [GGML_TYPE_Q4_2] = true,
     [GGML_TYPE_Q4_3] = true,
+    [GGML_TYPE_Q5_0] = true,
+    [GGML_TYPE_Q5_1] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_Q8_1] = true,
     [GGML_TYPE_I8]   = false,
     [GGML_TYPE_I16]  = false,
     [GGML_TYPE_I32]  = false,
 };
-static_assert(GGML_TYPE_COUNT == 11, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 13, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -6673,6 +7166,8 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
         case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
@@ -8161,6 +8656,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
         else if (type == GGML_TYPE_Q4_3) {
             dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
         }
+        else if (type == GGML_TYPE_Q5_0) {
+            dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
+        }
+        else if (type == GGML_TYPE_Q5_1) {
+            dequantize_row_q_cuda = dequantize_row_q5_1_cuda;
+        }
         else if (type == GGML_TYPE_Q8_0) {
             dequantize_row_q_cuda = dequantize_row_q8_0_cuda;
         }
@@ -8319,6 +8820,8 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
         case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
             {
@@ -8549,6 +9052,8 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
         case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
             {
@@ -12261,7 +12766,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK4_0; l += 2) {
-                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
@@ -12284,7 +12789,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK4_1; l += 2) {
-                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
@@ -12307,7 +12812,7 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK4_2; l += 2) {
-                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
@@ -12330,7 +12835,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
 
         for (int i = 0; i < nb; i++) {
             for (int l = 0; l < QK4_3; l += 2) {
-                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
                 const uint8_t vi1 = y[i].qs[l/2] >> 4;
 
                 hist[vi0]++;
@@ -12342,6 +12847,66 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK4_3*sizeof(block_q4_3));
 }
 
+size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK5_0 == 0);
+    const int nb = k / QK5_0;
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_0 * restrict y = (block_q5_0 *)dst + j/QK5_0;
+
+        quantize_row_q5_0_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            uint32_t qh;
+            memcpy(&qh, &y[i].qh, sizeof(qh));
+
+            for (int l = 0; l < QK5_0; l += 2) {
+                const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+                const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+                // cast to 16 bins
+                const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
+                const uint8_t vi1 = ((y[i].qs[l/2] >>   4) | vh1) / 2;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK5_0*sizeof(block_q5_0));
+}
+
+size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK5_1 == 0);
+    const int nb = k / QK5_1;
+
+    for (int j = 0; j < n; j += k) {
+        block_q5_1 * restrict y = (block_q5_1 *)dst + j/QK5_1;
+
+        quantize_row_q5_1_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            uint32_t qh;
+            memcpy(&qh, &y[i].qh, sizeof(qh));
+
+            for (int l = 0; l < QK5_1; l += 2) {
+                const uint8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4;
+                const uint8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4;
+
+                // cast to 16 bins
+                const uint8_t vi0 = ((y[i].qs[l/2] & 0x0F) | vh0) / 2;
+                const uint8_t vi1 = ((y[i].qs[l/2] >>   4) | vh1) / 2;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK5_1*sizeof(block_q5_1));
+}
+
 size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -12390,6 +12955,18 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
                 block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
                 result = ggml_quantize_q4_3(src + start, block, n, n, hist);
             } break;
+        case GGML_TYPE_Q5_0:
+            {
+                GGML_ASSERT(start % QK5_0 == 0);
+                block_q5_0 * block = (block_q5_0*)dst + start / QK5_0;
+                result = ggml_quantize_q5_0(src + start, block, n, n, hist);
+            } break;
+        case GGML_TYPE_Q5_1:
+            {
+                GGML_ASSERT(start % QK5_1 == 0);
+                block_q5_1 * block = (block_q5_1*)dst + start / QK5_1;
+                result = ggml_quantize_q5_1(src + start, block, n, n, hist);
+            } break;
         case GGML_TYPE_Q8_0:
             {
                 GGML_ASSERT(start % QK8_0 == 0);
diff --git a/ggml.h b/ggml.h
index 8300a0c62db9b364e28f1c2d1c972834be486756..d9d3d214e84e70f827d59b11ec2a35d47fb9b26f 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -222,8 +222,10 @@ extern "C" {
         GGML_TYPE_Q4_1 = 3,
         GGML_TYPE_Q4_2 = 4,
         GGML_TYPE_Q4_3 = 5,
-        GGML_TYPE_Q8_0 = 6,
-        GGML_TYPE_Q8_1 = 7,
+        GGML_TYPE_Q5_0 = 6,
+        GGML_TYPE_Q5_1 = 7,
+        GGML_TYPE_Q8_0 = 8,
+        GGML_TYPE_Q8_1 = 9,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -833,6 +835,8 @@ extern "C" {
     GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
     GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
 
     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
index 8334553a591dcae6559387239332e094fee4eb09..28a74b514b852488a996e3aec0d4142da027879a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -484,6 +484,8 @@ struct llama_file_loader {
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q4_2:
                 case GGML_TYPE_Q4_3:
+                case GGML_TYPE_Q5_0:
+                case GGML_TYPE_Q5_1:
                 case GGML_TYPE_Q8_0:
                     break;
                 default: {
@@ -559,6 +561,8 @@ struct llama_file_saver {
             case GGML_TYPE_Q4_1:
             case GGML_TYPE_Q4_2:
             case GGML_TYPE_Q4_3:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
             case GGML_TYPE_Q8_0:
                 break;
             default: LLAMA_ASSERT(false);
@@ -850,6 +854,8 @@ static const char *llama_ftype_name(enum llama_ftype ftype) {
                                       return "mostly Q4_1, some F16";
         case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2";
         case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3";
+        case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0";
+        case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1";
         case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0";
         default:                      return "unknown, may not work";
     }
@@ -1588,6 +1594,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
         case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break;
         case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break;
+        case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break;
         case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break;
         default: throw format("invalid output file type %d\n", ftype);
     };
diff --git a/llama.h b/llama.h
index 24c48cce624d8c43dbfb5d8c772d60a936a131af..17dac0689fbb5cb90e47f477fbfc0525fe84baad 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -75,6 +75,8 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_2 = 5,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_3 = 6,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q8_0 = 7,  // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q5_0 = 8,  // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q5_1 = 9,  // except 1d tensors
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();