]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508)
authorGeorgi Gerganov <redacted>
Fri, 19 May 2023 19:17:18 +0000 (22:17 +0300)
committerGitHub <redacted>
Fri, 19 May 2023 19:17:18 +0000 (22:17 +0300)
* ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0

* llama : bump LLAMA_FILE_VERSION to 3

* cuda : update Q4 and Q8 dequantize kernels

* ggml : fix AVX dot products

* readme : update performance table + hot topics

README.md
ggml-cuda.cu
ggml.c
ggml.h
llama.cpp
llama.h

index 6a67765aa5dafd2aa7bdcc41014710ccf0660264..762f4aa0349e564fe9843e0955fc55b47242e97a 100644 (file)
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 **Hot topics:**
 
+- Quantization formats `Q4` and `Q8` have changed again (19 May) - [(info)](https://github.com/ggerganov/llama.cpp/pull/1508)
 - Quantization formats `Q4` and `Q5` have changed - requantize any old models [(info)](https://github.com/ggerganov/llama.cpp/pull/1405)
 - [Roadmap May 2023](https://github.com/ggerganov/llama.cpp/discussions/1220)
 
@@ -334,16 +335,16 @@ Several quantization methods are supported. They differ in the resulting model d
 
 | Model | Measure      | F16    | Q4_0   | Q4_1   | Q5_0   | Q5_1   | Q8_0   |
 |------:|--------------|-------:|-------:|-------:|-------:|-------:|-------:|
-|    7B | perplexity   | 5.9066 | 6.1565 | 6.0910 | 5.9862 | 5.9481 | 5.9069 |
-|    7B | file size    |  13.0G |   4.0G |   4.8G |   4.4G |   4.8G |   7.1G |
-|    7B | ms/tok @ 4th |    128 |     50 |     54 |     75 |     83 |     75 |
-|    7B | ms/tok @ 8th |    123 |     44 |     52 |     53 |     58 |     72 |
-|    7B | bits/weight  |   16.0 |    5.0 |    6.0 |    5.5 |    6.0 |    9.0 |
-|   13B | perplexity   | 5.2543 | 5.3860 | 5.3607 | 5.2856 | 5.2706 | 5.2548 |
-|   13B | file size    |  25.0G |   7.6G |   9.1G |   8.4G |   9.1G |    14G |
-|   13B | ms/tok @ 4th |    239 |     93 |    101 |    150 |    164 |    141 |
-|   13B | ms/tok @ 8th |    240 |     81 |     96 |     96 |    104 |    136 |
-|   13B | bits/weight  |   16.0 |    5.0 |    6.0 |    5.5 |    6.0 |    9.0 |
+|    7B | perplexity   | 5.9066 | 6.1565 | 6.0912 | 5.9862 | 5.9481 | 5.9070 |
+|    7B | file size    |  13.0G |   3.5G |   3.9G |   4.3G |   4.7G |   6.7G |
+|    7B | ms/tok @ 4th |    127 |     55 |     54 |     76 |     83 |     72 |
+|    7B | ms/tok @ 8th |    122 |     43 |     45 |     52 |     56 |     67 |
+|    7B | bits/weight  |   16.0 |    4.5 |    5.0 |    5.5 |    6.0 |    8.5 |
+|   13B | perplexity   | 5.2543 | 5.3860 | 5.3608 | 5.2856 | 5.2706 | 5.2548 |
+|   13B | file size    |  25.0G |   6.8G |   7.6G |   8.3G |   9.1G |    13G |
+|   13B | ms/tok @ 4th |      - |    103 |    105 |    148 |    160 |    131 |
+|   13B | ms/tok @ 8th |      - |     73 |     82 |     98 |    105 |    128 |
+|   13B | bits/weight  |   16.0 |    4.5 |    5.0 |    5.5 |    6.0 |    8.5 |
 
 ### Perplexity (measuring model quality)
 
index f2630ec8ee1a16b822eab278bff9c49db7bfb930..688bcf79928d520ca37a4fa71a3c5f8a9b08dd30 100644 (file)
@@ -42,19 +42,19 @@ typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y,
 #define QK4_0 32
 #define QR4_0 2
 typedef struct {
-    float   d;              // delta
+    half    d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
 #define QR4_1 2
 typedef struct {
-    float   d;              // delta
-    float   m;              // min
+    half    d;              // delta
+    half    m;              // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 #define QR5_0 2
@@ -78,10 +78,10 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 #define QK8_0 32
 #define QR8_0 1
 typedef struct {
-    float   d;              // delta
+    half    d;              // delta
     int8_t  qs[QK8_0];      // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
diff --git a/ggml.c b/ggml.c
index dbef99312a56d6ec4cf5f1c7324b6cd1bee6f78e..1cb89636a0b21e0e886d59bf13d28973a19d91e6 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -769,18 +769,18 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
 
 #define QK4_0 32
 typedef struct {
-    float   d;          // delta
+    ggml_fp16_t d;          // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
 typedef struct {
-    float   d;          // delta
-    float   m;          // min
+    ggml_fp16_t d;          // delta
+    ggml_fp16_t m;          // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 typedef struct {
@@ -801,16 +801,16 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 
 #define QK8_0 32
 typedef struct {
-    float   d;          // delta
-    int8_t  qs[QK8_0];  // quants
+    ggml_fp16_t d;         // delta
+    int8_t  qs[QK8_0];     // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
 #define QK8_1 32
 typedef struct {
-    float   d;         // delta
-    float   s;         // d * sum(qs[i])
-    int8_t  qs[QK8_1]; // quants
+    float d;               // delta
+    float s;               // d * sum(qs[i])
+    int8_t  qs[QK8_1];     // quants
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
@@ -837,7 +837,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
         const float d  = max / -8;
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < qk/2; ++j) {
             const float x0 = x[i*qk + 0    + j]*id;
@@ -877,8 +877,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
         const float d  = (max - min) / ((1 << 4) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
-        y[i].m = min;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
 
         for (int j = 0; j < qk/2; ++j) {
             const float x0 = (x[i*qk + 0    + j] - min)*id;
@@ -1009,7 +1009,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < QK8_0; ++j) {
             const float x0 = x[i*QK8_0 + j]*id;
@@ -1044,7 +1044,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < 8; j++) {
             const float32x4_t v  = vmulq_n_f32(srcv[j], id);
@@ -1079,7 +1079,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         // Quantize these floats
         const float d = maxScalar / 127.f;
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
         const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );
 
@@ -1178,7 +1178,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
             sum += y[i].qs[QK8_1/2 + j];
         }
 
-        y[i].s = d * sum;
+        y[i].s = sum*d;
     }
 }
 
@@ -1330,7 +1330,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
     const int nb = k / qk;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
 
         for (int j = 0; j < qk/2; ++j) {
             const int x0 = (x[i].qs[j] & 0x0F) - 8;
@@ -1350,8 +1350,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
     const int nb = k / qk;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
-        const float m = x[i].m;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
 
         for (int j = 0; j < qk/2; ++j) {
             const int x0 = (x[i].qs[j] & 0x0F);
@@ -1426,7 +1426,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
     const block_q8_0 * restrict x = vx;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
 
         for (int j = 0; j < qk; ++j) {
             y[i*qk + j] = x[i].qs[j]*d;
@@ -1690,8 +1690,9 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
     float tmp[8];
 
-    for (int i = 0; i < 8; i++)
+    for (int i = 0; i < 8; i++) {
         tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
 
     return _mm256_loadu_ps(tmp);
 }
@@ -2111,8 +2112,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b   = vdupq_n_u8(0x0F);
-        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2140,8 +2141,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
@@ -2158,8 +2159,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -2171,7 +2172,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
 
@@ -2195,7 +2196,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         const __m128i lowMask = _mm_set1_epi8(0xF);
         const __m128i off = _mm_set1_epi8(8);
@@ -2237,7 +2238,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
 
         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
 
@@ -2255,7 +2256,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
 
         const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
 
@@ -2288,7 +2289,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
 
@@ -2306,7 +2307,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
 
         const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
 
@@ -2354,7 +2355,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (x[i].d*y[i].d)*sumi;
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
     }
 
     *s = sumf;
@@ -2384,7 +2385,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const block_q8_1 * restrict y0 = &y[i + 0];
         const block_q8_1 * restrict y1 = &y[i + 1];
 
-        summs += x0->m * y0->s + x1->m * y1->s;
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -2408,8 +2409,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
@@ -2426,8 +2427,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #endif
     }
 
@@ -2440,13 +2441,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
 
     // Main loop
     for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
+        const float d0 = GGML_FP16_TO_FP32(x[i].d);
+        const float d1 = y[i].d;
 
-        summs += x[i].m * y[i].s;
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
 
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
@@ -2480,7 +2481,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
+        sumf += (GGML_FP16_TO_FP32(x[i]).d*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
 
     *s = sumf;
@@ -2556,16 +2557,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-        const float x1d = GGML_FP16_TO_FP32(x1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2582,8 +2580,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -2658,7 +2656,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; i++) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
         __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2682,7 +2680,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; i++) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2725,7 +2723,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
             sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
     }
 
     *s = sumf;
@@ -2807,16 +2805,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-        const float x1d = GGML_FP16_TO_FP32(x1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2833,8 +2828,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #endif
     }
 
@@ -2894,15 +2889,14 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
         const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-
         // dot product
-        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
-                        wasm_i32x4_add(
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
-                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
-                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                               wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                               wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d));
     }
 
     *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
@@ -2924,7 +2918,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
         bx = _mm256_or_si256(bx, bxhi);
 
-        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
         const __m256 q = mul_sum_us8_pairs_float(bx, by);
@@ -2958,7 +2952,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         bxh = _mm_or_si128(bxh, bxhih);
         bx = _mm256_set_m128i(bxh, bxl);
 
-        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
         const __m256 q = mul_sum_us8_pairs_float(bx, by);
@@ -3028,11 +3022,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
-                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
 
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
-                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 
 #else
         const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
@@ -3050,8 +3044,8 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
         const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -3063,7 +3057,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
         __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3089,7 +3083,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
             sumi += x[i].qs[j]*y[i].qs[j];
         }
 
-        sumf += (x[i].d*y[i].d)*sumi;
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
     }
 
     *s = sumf;
diff --git a/ggml.h b/ggml.h
index 255541d0257e3abea7461e6c8a84a4cb5f1fe583..dce5ca1e7cb61957e0e3f9d54f490862bd28d7ee 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_FILE_MAGIC   0x67676d6c // "ggml"
 #define GGML_FILE_VERSION 1
 
-#define GGML_QNT_VERSION        1    // bump this on quantization format changes
+#define GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS          4
index 1802d2319633e678e231c47cfc667a3334333c16..6ebe85d0fed25033c69382ca30da36825a35d0ee 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -406,6 +406,7 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
+    LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
 };
 
 struct llama_file_loader {
@@ -438,6 +439,8 @@ struct llama_file_loader {
             file_version = LLAMA_FILE_VERSION_GGJT_V1;
         } else if (magic == 'ggjt' && version == 2) {
             file_version = LLAMA_FILE_VERSION_GGJT_V2;
+        } else if (magic == 'ggjt' && version == 3) {
+            file_version = LLAMA_FILE_VERSION_GGJT_V3;
         } else {
             throw format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?",
                          magic, version);
@@ -844,7 +847,8 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGML: return "'ggml' (old version with low tokenizer quality and no mmap support)";
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
-        case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
     }
 
     return "unknown";
@@ -924,11 +928,19 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: model size = %s\n",  __func__, llama_model_type_name(model.type));
     }
 
-    if (file_version != LLAMA_FILE_VERSION_GGJT_V2) {
+    if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
         if (hparams.ftype != LLAMA_FTYPE_ALL_F32     &&
             hparams.ftype != LLAMA_FTYPE_MOSTLY_F16  &&
             hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) {
-            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1305)");
+            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)");
+        }
+    }
+
+    if (file_version < LLAMA_FILE_VERSION_GGJT_V3) {
+        if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
+            hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ||
+            hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) {
+            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)");
         }
     }
 
diff --git a/llama.h b/llama.h
index f955fa23db048ffd5dbfc5e8e03faf5dba192bfa..fd3f21e5fa30a972f651ea98a33eff0694f2ee13 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -19,7 +19,7 @@
 #    define LLAMA_API
 #endif
 
-#define LLAMA_FILE_VERSION           2
+#define LLAMA_FILE_VERSION           3
 #define LLAMA_FILE_MAGIC             'ggjt'
 #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
 #define LLAMA_SESSION_MAGIC          'ggsn'