]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add Q4_3 quantization (#1082)
authorGeorgi Gerganov <redacted>
Thu, 20 Apr 2023 17:35:53 +0000 (20:35 +0300)
committerGitHub <redacted>
Thu, 20 Apr 2023 17:35:53 +0000 (20:35 +0300)
examples/quantize/quantize.cpp
ggml.c
ggml.h
llama.cpp
llama.h

index 59cb6744016cb733ac83ebca51d8927bfbdff2d1..49a33a86f10a8c1aab2746992746fb7e502a46be 100644 (file)
@@ -15,6 +15,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "  type = %d - q4_0\n", LLAMA_FTYPE_MOSTLY_Q4_0);
         fprintf(stderr, "  type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1);
         fprintf(stderr, "  type = %d - q4_2\n", LLAMA_FTYPE_MOSTLY_Q4_2);
+        fprintf(stderr, "  type = %d - q4_3\n", LLAMA_FTYPE_MOSTLY_Q4_3);
         return 1;
     }
 
diff --git a/ggml.c b/ggml.c
index 35b15cc2ee7805bed7875bfdaf90b0ad2f1fc86f..733ddc0dea0b212345d6002d76d114f326a88995 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -637,7 +637,7 @@ typedef struct {
     float   m;          // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK4_2 16
 typedef struct {
@@ -646,6 +646,14 @@ typedef struct {
 } block_q4_2;
 static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
 
+#define QK4_3 16
+typedef struct {
+    ggml_fp16_t d;         // delta
+    ggml_fp16_t m;         // min
+    uint8_t qs[QK4_3 / 2]; // nibbles / quants
+} block_q4_3;
+static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
+
 #define QK8_0 32
 typedef struct {
     float   d;          // delta
@@ -1203,7 +1211,6 @@ static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restri
     const int nb = k / QK4_2;
 
     for (int i = 0; i < nb; i++) {
-
         float scale = kquantize_q4_with_bounds(QK4_2, -8, 7, x, CANDIDATE_COUNT, candidates, L);
         y[i].d = GGML_FP32_TO_FP16(scale);
 
@@ -1231,6 +1238,49 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
     quantize_row_q4_2_rmse(x, y, k);
 }
 
+static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int l = 0; l < QK4_3; l++) {
+            const float v = x[i*QK4_3 + l];
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
+            const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
+
+            const uint8_t vi0 = (int) (v0 + 0.5f);
+            const uint8_t vi1 = (int) (v1 + 0.5f);
+
+            assert(vi0 < 16);
+            assert(vi1 < 16);
+
+            y[i].qs[l/2] = vi0 | (vi1 << 4);
+        }
+    }
+}
+
+static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
+    assert(k % QK4_3 == 0);
+
+    block_q4_3 * restrict y = vy;
+
+    quantize_row_q4_3_reference(x, y, k);
+}
+
 // reference implementation for deterministic creation of model files
 static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
     assert(k % QK8_0 == 0);
@@ -1635,9 +1685,40 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }
 
+static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    const block_q4_3 * restrict x = vx;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        const uint8_t * restrict pp = x[i].qs;
+
+        for (int l = 0; l < QK4_3; l += 2) {
+            const uint8_t vi = pp[l/2];
+
+            const int8_t vi0 = vi & 0xf;
+            const int8_t vi1 = vi >> 4;
+
+            const float v0 = vi0*d + m;
+            const float v1 = vi1*d + m;
+
+            y[i*QK4_3 + l + 0] = v0;
+            y[i*QK4_3 + l + 1] = v1;
+
+            assert(!isnan(y[i*QK4_3 + l + 0]));
+            assert(!isnan(y[i*QK4_3 + l + 1]));
+        }
+    }
+}
+
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
@@ -1661,6 +1742,13 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .quantize_row_q_dot       = quantize_row_q8_0,
         .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
     },
+    [GGML_TYPE_Q4_3] = {
+        .dequantize_row_q         = dequantize_row_q4_3,
+        .quantize_row_q           = quantize_row_q4_3,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference, // TODO: RMSE optimization
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_3_q8_0,
+    },
     [GGML_TYPE_Q8_0] = {
         .dequantize_row_q         = NULL,   // TODO
         .quantize_row_q           = quantize_row_q8_0,
@@ -2655,6 +2743,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
         const block_q4_2 * restrict x0_1 = &x[2*(i + 0) + 1];
         const block_q4_2 * restrict x1_0 = &x[2*(i + 1) + 0];
         const block_q4_2 * restrict x1_1 = &x[2*(i + 1) + 1];
+
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
@@ -2809,6 +2898,154 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     *s = sumf;
 }
 
+static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+    assert(QK8_0 == 2*QK4_2);
+
+    const block_q4_3 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
+        const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
+        const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
+
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
+        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
+        const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
+        const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
+
+        const float x0_0m = GGML_FP16_TO_FP32(x0_0->m);
+        const float x0_1m = GGML_FP16_TO_FP32(x0_1->m);
+        const float x1_0m = GGML_FP16_TO_FP32(x1_0->m);
+        const float x1_1m = GGML_FP16_TO_FP32(x1_1->m);
+
+        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // interleave
+        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
+        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l)));
+        const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h)));
+
+        const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l)));
+        const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h)));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * restrict x0 = x[2*i + 0].qs;
+        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const  int8_t * restrict y0 = y[i].qs;
+
+        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
+        const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
+        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
+        const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
+
+        int sy_0 = 0;
+        int sy_1 = 0;
+
+        int sxy_0 = 0;
+        int sxy_1 = 0;
+
+        for (int j = 0; j < QK8_0/4; j++) {
+            const uint8_t v0 = x0[j];
+            const uint8_t v1 = x1[j];
+
+            const int x0_0 = v0 & 0xf;
+            const int x1_0 = v0 >> 4;
+
+            const int x0_1 = v1 & 0xf;
+            const int x1_1 = v1 >> 4;
+
+            const int y0_0 = y0[2*j + 0];
+            const int y1_0 = y0[2*j + 1];
+
+            const int y0_1 = y0[2*(j + QK8_0/4) + 0];
+            const int y1_1 = y0[2*(j + QK8_0/4) + 1];
+
+            sy_0 += y0_0 + y1_0;
+            sy_1 += y0_1 + y1_1;
+
+            sxy_0 += x0_0*y0_0 + x1_0*y1_0;
+            sxy_1 += x0_1*y0_1 + x1_1*y1_1;
+        }
+
+        sumf += (d0*sxy_0 + m0*sy_0)*y[i].d;
+        sumf += (d1*sxy_1 + m1*sy_1)*y[i].d;
+    }
+#endif
+
+    *s = sumf;
+}
+
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -3056,12 +3293,13 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = QK4_0,
     [GGML_TYPE_Q4_1] = QK4_1,
     [GGML_TYPE_Q4_2] = QK4_2,
+    [GGML_TYPE_Q4_3] = QK4_3,
     [GGML_TYPE_Q8_0] = QK8_0,
     [GGML_TYPE_I8]   = 1,
     [GGML_TYPE_I16]  = 1,
     [GGML_TYPE_I32]  = 1,
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_BLCK_SIZE is outdated");
 
 static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = sizeof(float),
@@ -3069,12 +3307,13 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
     [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
     [GGML_TYPE_Q4_2] = sizeof(block_q4_2),
+    [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
     [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
     [GGML_TYPE_I8]   = sizeof(int8_t),
     [GGML_TYPE_I16]  = sizeof(int16_t),
     [GGML_TYPE_I32]  = sizeof(int32_t),
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_SIZE is outdated");
 
 
 static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -3083,12 +3322,13 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = "q4_0",
     [GGML_TYPE_Q4_1] = "q4_1",
     [GGML_TYPE_Q4_2] = "q4_2",
+    [GGML_TYPE_Q4_3] = "q4_3",
     [GGML_TYPE_Q8_0] = "q8_0",
     [GGML_TYPE_I8]   = "i8",
     [GGML_TYPE_I16]  = "i16",
     [GGML_TYPE_I32]  = "i32",
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_TYPE_NAME is outdated");
 
 static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32]  = false,
@@ -3096,12 +3336,13 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = true,
     [GGML_TYPE_Q4_1] = true,
     [GGML_TYPE_Q4_2] = true,
+    [GGML_TYPE_Q4_3] = true,
     [GGML_TYPE_Q8_0] = true,
     [GGML_TYPE_I8]   = false,
     [GGML_TYPE_I16]  = false,
     [GGML_TYPE_I32]  = false,
 };
-static_assert(GGML_TYPE_COUNT == 9, "GGML_IS_QUANTIZED is outdated");
+static_assert(GGML_TYPE_COUNT == 10, "GGML_IS_QUANTIZED is outdated");
 
 static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "NONE",
@@ -3363,7 +3604,7 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
-static inline bool ggml_is_quantized(enum ggml_type type) {
+bool ggml_is_quantized(enum ggml_type type) {
     return GGML_IS_QUANTIZED[type];
 }
 
@@ -6313,6 +6554,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -7798,6 +8040,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
         else if (type == GGML_TYPE_Q4_2) {
             dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
         }
+        else if (type == GGML_TYPE_Q4_3) {
+            dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
+        }
         else {
             GGML_ASSERT(false);
         }
@@ -7952,6 +8197,7 @@ static void ggml_compute_forward_mul_mat(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
@@ -7969,34 +8215,6 @@ static void ggml_compute_forward_mul_mat(
                 GGML_ASSERT(false);
             } break;
     }
-
-#if 0
-    if (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_Q4_1) {
-        static int first = 8;
-        printf("src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-        if (first) {
-            --first;
-        } else {
-            for (int k = 0; k < dst->ne[1]; ++k) {
-                for (int j = 0; j < dst->ne[0]/16; ++j) {
-                    for (int i = 0; i < 16; ++i) {
-                        printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
-                    }
-                    printf("\n");
-                }
-                printf("\n");
-            }
-            printf("\n");
-            exit(0);
-        }
-    } else {
-        printf("aaaa src0: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src0->ne[0], src0->ne[1], src0->ne[2]);
-        printf("aaaa src1: ne0 = %5d, ne1 = %5d, ne2 = %5d\n", src1->ne[0], src1->ne[1], src1->ne[2]);
-        printf("aaaa dst:  ne0 = %5d, ne1 = %5d, ne2 = %5d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
-    }
-#endif
 }
 
 // ggml_compute_forward_scale
@@ -8208,6 +8426,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
         case GGML_TYPE_Q8_0:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
@@ -11947,6 +12166,29 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
     return (n/QK4_2*sizeof(block_q4_2));
 }
 
+size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
+    assert(k % QK4_3 == 0);
+    const int nb = k / QK4_3;
+
+    for (int j = 0; j < n; j += k) {
+        block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
+
+        quantize_row_q4_3_reference(src + j, y, k);
+
+        for (int i = 0; i < nb; i++) {
+            for (int l = 0; l < QK4_3; l += 2) {
+                const uint8_t vi0 = y[i].qs[l/2] & 0xF;
+                const uint8_t vi1 = y[i].qs[l/2] >> 4;
+
+                hist[vi0]++;
+                hist[vi1]++;
+            }
+        }
+    }
+
+    return (n/QK4_3*sizeof(block_q4_3));
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 int ggml_cpu_has_avx(void) {
diff --git a/ggml.h b/ggml.h
index 570147fc246587603042beed51a0cdb61bae17f1..6e81d8125ae6f1b7239ff54e0d6c356c26cae964 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -205,7 +205,8 @@ enum ggml_type {
     GGML_TYPE_Q4_0 = 2,
     GGML_TYPE_Q4_1 = 3,
     GGML_TYPE_Q4_2 = 4,
-    GGML_TYPE_Q8_0 = 5,
+    GGML_TYPE_Q4_3 = 5,
+    GGML_TYPE_Q8_0 = 6,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,
@@ -360,6 +361,8 @@ const char * ggml_type_name(enum ggml_type type);
 
 size_t ggml_element_size(const struct ggml_tensor * tensor);
 
+bool ggml_is_quantized(enum ggml_type type);
+
 struct ggml_context * ggml_init(struct ggml_init_params params);
 void ggml_free(struct ggml_context * ctx);
 
@@ -808,6 +811,7 @@ enum ggml_opt_result ggml_opt(
 size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
 size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
+size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);
 
 //
 // system info
index 3ff5dc1e14800952e24b19647c87883c6803900f..99d29a1ef7b092790e26f01aba4735f9d3a1d65b 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -479,6 +479,7 @@ struct llama_file_loader {
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q4_2:
+                case GGML_TYPE_Q4_3:
                     break;
                 default: {
                     throw format("unrecognized tensor type %u\n", shard.type);
@@ -552,6 +553,7 @@ struct llama_file_saver {
             case GGML_TYPE_Q4_0:
             case GGML_TYPE_Q4_1:
             case GGML_TYPE_Q4_2:
+            case GGML_TYPE_Q4_3:
                 break;
             default: LLAMA_ASSERT(false);
         }
@@ -841,6 +843,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
                                       return "mostly Q4_1, some F16";
         case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2";
+        case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3";
         default:                      return "unknown, may not work";
     }
 }
@@ -1575,6 +1578,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
         case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
         case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break;
         default: throw format("invalid output file type %d\n", ftype);
     };
 
@@ -1652,6 +1656,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                     {
                         new_size = ggml_quantize_q4_2(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
                     } break;
+                case GGML_TYPE_Q4_3:
+                    {
+                        new_size = ggml_quantize_q4_3(f32_data, new_data, nelements, (int) tensor.ne.at(0), hist_cur.data());
+                    } break;
                 default:
                     LLAMA_ASSERT(false);
             }
@@ -1963,7 +1971,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
                 base_t = dest_t;
             }
 
-            if (base_t->type == GGML_TYPE_Q4_0 || base_t->type == GGML_TYPE_Q4_1 || base_t->type == GGML_TYPE_Q4_2) {
+            if (ggml_is_quantized(base_t->type)) {
                 if (!warned) {
                     fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, "
                                     "use a f16 or f32 base model with --lora-base\n", __func__);
diff --git a/llama.h b/llama.h
index 208b03d18056cfabd598a9e61bd1792a051f6d6a..011e34c00fa3cfcbecb4d043fac84cdb16144326 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -73,6 +73,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q4_1 = 3,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
         LLAMA_FTYPE_MOSTLY_Q4_2 = 5,  // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_3 = 6,  // except 1d tensors
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();