]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : generalize `quantize_fns` for simpler FP16 handling (#1237)
authorStephan Walter <redacted>
Wed, 5 Jul 2023 16:13:06 +0000 (16:13 +0000)
committerGitHub <redacted>
Wed, 5 Jul 2023 16:13:06 +0000 (19:13 +0300)
* Generalize quantize_fns for simpler FP16 handling

* Remove call to ggml_cuda_mul_mat_get_wsize

* ci : disable FMA for mac os actions

---------

Co-authored-by: Georgi Gerganov <redacted>
.github/workflows/build.yml
examples/quantize-stats/quantize-stats.cpp
ggml.c
ggml.h
llama.cpp
pocs/vdot/q8dot.cpp
pocs/vdot/vdot.cpp
tests/test-quantize-fns.cpp
tests/test-quantize-perf.cpp

index aec43bd923bb5ecaec5cc0e7ad232e4d5738e428..12481e8be7cf7ec6adff597717a5b2e755a09d2d 100644 (file)
@@ -137,9 +137,10 @@ jobs:
       - name: Build
         id: cmake_build
         run: |
+          sysctl -a
           mkdir build
           cd build
-          cmake -DLLAMA_AVX2=OFF ..
+          cmake -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF ..
           cmake --build . --config Release
 
       - name: Test
index 9cea472dedb82ca95488227acdd0181f50cb3713..6aa06ec8fa1152dc3474c071ac3657b031dbe08c 100644 (file)
@@ -147,7 +147,7 @@ void test_roundtrip_on_chunk(
         const ggml_tensor * layer,
         int64_t offset,
         int64_t chunk_size,
-        const quantize_fns_t & qfns,
+        const ggml_type_traits_t & qfns,
         bool use_reference,
         float * input_scratch,
         char * quantized_scratch,
@@ -163,11 +163,11 @@ void test_roundtrip_on_chunk(
     }
 
     if (use_reference) {
-        qfns.quantize_row_q_reference(input_scratch, quantized_scratch, chunk_size);
+        qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size);
     } else {
-        qfns.quantize_row_q(input_scratch, quantized_scratch, chunk_size);
+        qfns.from_float(input_scratch, quantized_scratch, chunk_size);
     }
-    qfns.dequantize_row_q(quantized_scratch, output_scratch, chunk_size);
+    qfns.to_float(quantized_scratch, output_scratch, chunk_size);
 
     update_error_stats(chunk_size, input_scratch, output_scratch, stats);
 }
@@ -177,7 +177,7 @@ void test_roundtrip_on_chunk(
 void test_roundtrip_on_layer(
         std::string & name,
         bool print_layer_stats,
-        const quantize_fns_t & qfns,
+        const ggml_type_traits_t & qfns,
         bool use_reference,
         const ggml_tensor * layer,
         std::vector<float> & input_scratch,
@@ -388,8 +388,8 @@ int main(int argc, char ** argv) {
         if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) {
             continue;
         }
-        quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
-        if (qfns.quantize_row_q && qfns.dequantize_row_q) {
+        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
+        if (qfns.from_float && qfns.to_float) {
             if (params.verbose) {
                 printf("testing %s ...\n",  ggml_type_name(type));
             }
diff --git a/ggml.c b/ggml.c
index 88cbed7d5347c6e8cb07003769385860d4d450c9..635c32eb5e213bacf7b5d50823cffbeaa31d3e6b 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -481,14 +481,14 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
     return GGML_FP32_TO_FP16(x);
 }
 
-void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
-    for (size_t i = 0; i < n; i++) {
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) {
+    for (int i = 0; i < n; i++) {
         y[i] = GGML_FP16_TO_FP32(x[i]);
     }
 }
 
-void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
-    size_t i = 0;
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) {
+    int i = 0;
 #if defined(__F16C__)
     for (; i + 7 < n; i += 8) {
         __m256 x_vec = _mm256_loadu_ps(x + i);
@@ -1627,109 +1627,112 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
     }
 }
 
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+    },
+    [GGML_TYPE_F16] = {
+        .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,
+        .from_float               = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .from_float_reference     = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
+        .vec_dot_type             = GGML_TYPE_F16,
+    },
     [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q4_0_q8_0,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,
+        .from_float               = quantize_row_q4_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_0_reference,
+        .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = (dequantize_row_q_t)dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1_q8_1,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
+        .from_float               = quantize_row_q4_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_1_reference,
+        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q5_0] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_0,
-        .quantize_row_q           = quantize_row_q5_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q5_0_q8_0,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,
+        .from_float               = quantize_row_q5_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_0_reference,
+        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q5_1] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_1,
-        .quantize_row_q           = quantize_row_q5_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = ggml_vec_dot_q5_1_q8_1,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,
+        .from_float               = quantize_row_q5_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_1_reference,
+        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q8_0] = {
-        .dequantize_row_q         = dequantize_row_q8_0,
-        .quantize_row_q           = quantize_row_q8_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
-        .quantize_row_q_dot       = quantize_row_q8_0,
-        .vec_dot_q                = ggml_vec_dot_q8_0_q8_0,
+        .to_float                 = dequantize_row_q8_0,
+        .from_float               = quantize_row_q8_0,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_0_reference,
+        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q8_1] = {
-        .dequantize_row_q         = NULL,   // TODO
-        .quantize_row_q           = quantize_row_q8_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
-        .quantize_row_q_dot       = quantize_row_q8_1,
-        .vec_dot_q                = NULL,   // TODO
+        .from_float               = quantize_row_q8_1,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q8_1_reference,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
 #ifdef GGML_USE_K_QUANTS
     [GGML_TYPE_Q2_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q2_K,
-        .quantize_row_q           = quantize_row_q2_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q2_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,
+        .from_float               = quantize_row_q2_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q2_K_reference,
+        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q3_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q3_K,
-        .quantize_row_q           = quantize_row_q3_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q3_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,
+        .from_float               = quantize_row_q3_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q3_K_reference,
+        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q4_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_K,
-        .quantize_row_q           = quantize_row_q4_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q4_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,
+        .from_float               = quantize_row_q4_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q4_K_reference,
+        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q5_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q5_K,
-        .quantize_row_q           = quantize_row_q5_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q5_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,
+        .from_float               = quantize_row_q5_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q5_K_reference,
+        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q6_K] = {
-        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q6_K,
-        .quantize_row_q           = quantize_row_q6_K,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
-        .quantize_row_q_dot       = quantize_row_q8_K,
-        .vec_dot_q                = ggml_vec_dot_q6_K_q8_K,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,
+        .from_float               = quantize_row_q6_K,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q6_K_reference,
+        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
+    [GGML_TYPE_Q8_K] = {
+        .from_float               = quantize_row_q8_K,
+    }
 #endif
 };
 
 // For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i) {
     GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
+    return type_traits[i];
 }
 
 
@@ -2275,7 +2278,7 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
 inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
-inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
 #ifdef GGML_SIMD
     float sumf = 0.0f;
     const int np = (n & ~(GGML_F32_STEP - 1));
@@ -2312,7 +2315,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
 #if defined(GGML_SIMD)
@@ -7825,8 +7828,8 @@ static void ggml_compute_forward_dup_f16(
                         id += ne00 * (ne01 - ir1);
                     }
                 }
-            } else if (ggml_is_quantized(dst->type)) {
-                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+            } else if (type_traits[dst->type].from_float) {
+                ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
                 float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
                 size_t id = 0;
@@ -8078,26 +8081,8 @@ static void ggml_compute_forward_dup_f32(
                         id += rs * (ne01 - ir1);
                     }
                 }
-            } else if (dst->type == GGML_TYPE_F16) {
-                size_t id = 0;
-                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += ne00 * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            for (int i00 = 0; i00 < ne00; i00++) {
-                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
-                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
-                                id++;
-                            }
-                        }
-                        id += ne00 * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_is_quantized(dst->type)) {
-                quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+            } else if (type_traits[dst->type].from_float) {
+                ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
 
                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
@@ -8503,8 +8488,8 @@ static void ggml_compute_forward_add_q_f32(
     const int nth = params->nth;
 
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+    ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -8777,8 +8762,8 @@ static void ggml_compute_forward_add1_q_f32(
     GGML_TENSOR_UNARY_OP_LOCALS;
 
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+    ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
 
     // we don't support permuted src0
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -10578,317 +10563,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 }
 #endif
 
-static void ggml_compute_forward_mul_mat_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    assert(ne02 == ne12);
-    assert(ne03 == ne13);
-    assert(ne2  == ne12);
-    assert(ne3  == ne13);
-
-    // we don't support permuted src0 or src1
-    assert(nb00 == sizeof(float));
-    assert(nb10 == sizeof(float));
-
-    // dst cannot be transposed or permuted
-    assert(nb0 == sizeof(float));
-    assert(nb0 <= nb1);
-    assert(nb1 <= nb2);
-    assert(nb2 <= nb3);
-
-    assert(ne0 == ne01);
-    assert(ne1 == ne11);
-    assert(ne2 == ne02);
-    assert(ne3 == ne03);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-#if defined(GGML_USE_CLBLAST)
-    if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#endif
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        if (params->ith != 0) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_INIT) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_FINALIZE) {
-            return;
-        }
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
-                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne00,
-                        0.0f,    d, ne01);
-            }
-        }
-        //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
-
-        return;
-    }
-#endif
-
-    if (params->type == GGML_TASK_INIT) {
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // parallelize by src0 rows using ggml_vec_dot_f32
-
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            // src1 indices
-            const int i13 = i03;
-            const int i12 = i02;
-            const int i11 = ic;
-
-            // dst indices
-            const int i0 = i01;
-            const int i1 = i11;
-            const int i2 = i02;
-            const int i3 = i03;
-
-            ggml_vec_dot_f32(ne00,
-                    (float *) ((char *)  dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
-                    (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
-                    (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
-        }
-    }
-
-    //int64_t t1 = ggml_perf_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
-}
-
-static void ggml_compute_forward_mul_mat_f16_f32(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
-    int64_t t0 = ggml_perf_time_us();
-    UNUSED(t0);
-
-    GGML_TENSOR_BINARY_OP_LOCALS;
-
-    //const int64_t ne   = ne0*ne1*ne2*ne3;
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne03 == ne13);
-    GGML_ASSERT(ne2  == ne12);
-    GGML_ASSERT(ne3  == ne13);
-
-    // TODO: we don't support permuted src0
-    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
-    // dst cannot be transposed or permuted
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(nb0 <= nb1);
-    GGML_ASSERT(nb1 <= nb2);
-    GGML_ASSERT(nb2 <= nb3);
-
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne02);
-    GGML_ASSERT(ne3 == ne03);
-
-    // nb01 >= nb00 - src0 is not transposed
-    //   compute by src0 rows
-
-#if defined(GGML_USE_CLBLAST)
-    if (ggml_cl_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#endif
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-    if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
-        GGML_ASSERT(nb10 == sizeof(float));
-
-        if (params->ith != 0) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_INIT) {
-            return;
-        }
-
-        if (params->type == GGML_TASK_FINALIZE) {
-            return;
-        }
-
-        for (int64_t i03 = 0; i03 < ne03; i03++) {
-            for (int64_t i02 = 0; i02 < ne02; i02++) {
-                float * const wdata = params->wdata;
-                {
-                    size_t id = 0;
-                    for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        for (int64_t i00 = 0; i00 < ne00; ++i00) {
-                            wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
-                        }
-                    }
-
-                    assert(id*sizeof(float) <= params->wsize);
-                }
-
-                const float * x = wdata;
-                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
-                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
-                // zT = y * xT
-                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                        ne11, ne01, ne10,
-                        1.0f,    y, ne10,
-                                 x, ne00,
-                        0.0f,    d, ne01);
-            }
-        }
-
-        /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
-
-        return;
-    }
-#endif
-
-    if (params->type == GGML_TASK_INIT) {
-        ggml_fp16_t * const wdata = params->wdata;
-
-        size_t id = 0;
-        for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    for (int64_t i10 = 0; i10 < ne10; ++i10) {
-                        wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
-                    }
-                }
-            }
-        }
-
-        GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
-        return;
-    }
-
-    if (params->type == GGML_TASK_FINALIZE) {
-        return;
-    }
-
-    // fp16 -> half the size, so divide by 2
-    // TODO: do not support transposed src1
-    assert(nb10/2 == sizeof(ggml_fp16_t));
-
-    // parallelize by src0 rows using ggml_vec_dot_f16
-
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    ggml_fp16_t * wdata = params->wdata;
-
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-        const int i13 = i03;
-        const int i12 = i02;
-
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
-
-        ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        ggml_fp16_t * src1_col =                                wdata + (       0 + i12*ne11 + i13*ne12*ne11)*ne00;
-
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
-        }
-    }
-
-    //int64_t t1 = ggml_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
-}
-
-static void ggml_compute_forward_mul_mat_q_f32(
+static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
@@ -10907,9 +10582,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
     GGML_ASSERT(ne3  == ne13);
 
     const enum ggml_type type = src0->type;
-    quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
-    vec_dot_q_t      const vec_dot_q          = quantize_fns[type].vec_dot_q;
-    enum ggml_type   const vec_dot_type       = quantize_fns[type].vec_dot_type;
+
+    ggml_vec_dot_t    const vec_dot               = type_traits[type].vec_dot;
+    enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
+    ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
@@ -10952,27 +10628,27 @@ static void ggml_compute_forward_mul_mat_q_f32(
             return;
         }
 
-        float * const wdata = params->wdata;
-        dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void * x = (char *) src0->data + i03*nb03 + i02*nb02;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                {
+                if (type != GGML_TYPE_F32) {
+                    float * const wdata = params->wdata;
+                    ggml_to_float_t const to_float = type_traits[type].to_float;
+
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
-                        dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+                        to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
 
                     assert(id*sizeof(float) <= params->wsize);
+                    x = wdata;
                 }
 
-                const float * x = wdata;
-
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
@@ -10988,14 +10664,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #endif
 
     if (params->type == GGML_TASK_INIT) {
-        char * wdata = params->wdata;
-        const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
-
-        for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                    wdata += row_size;
+        if (src1->type != vec_dot_type) {
+            char * wdata = params->wdata;
+            const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+
+            for (int64_t i13 = 0; i13 < ne13; ++i13) {
+                for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                    for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                        wdata += row_size;
+                    }
                 }
             }
         }
@@ -11019,7 +10697,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    void * wdata = params->wdata;
+    void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -11043,7 +10721,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
         assert(ne00 % 32 == 0);
 
         for (int64_t ic = 0; ic < ne11; ++ic) {
-            vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+            vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
         }
     }
 
@@ -11060,40 +10738,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
     //}
 }
 
-static void ggml_compute_forward_mul_mat(
-        const struct ggml_compute_params * params,
-        const struct ggml_tensor * src0,
-        const struct ggml_tensor * src1,
-        struct ggml_tensor * dst) {
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-            {
-                ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
-            } break;
-        default:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
 
 // ggml_compute_forward_out_prod
 
@@ -11483,7 +11127,7 @@ static void ggml_compute_forward_get_rows_q(
     const int nc = src0->ne[0];
     const int nr = ggml_nelements(src1);
     const enum ggml_type type = src0->type;
-    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
     assert( dst->ne[0] == nc);
     assert( dst->ne[1] == nr);
@@ -16529,6 +16173,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
 
                         size_t cur = 0;
+                        const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
 
 #if defined(GGML_USE_CUBLAS)
                         if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
@@ -16544,37 +16189,18 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         }
                         else
 #endif
-                        if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                node->n_tasks = 1; // TODO: this actually is doing nothing
-                                                   //       the threads are still spinning
+                        if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+                            node->n_tasks = 1; // TODO: this actually is doing nothing
+                                               //       the threads are still spinning
+                            if (node->src0->type != GGML_TYPE_F32) {
                                 // here we need memory just for single 2D matrix from src0
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            } else {
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-                            }
-#else
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-#endif
-                        } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
-                            cur = 0;
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                node->n_tasks = 1;
                             }
+                        } else
 #endif
-                        } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
-                            if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
-                                node->n_tasks = 1;
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
-                            } else
-#endif
-                            {
-                                const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
-                                cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
-                            }
+                        if (node->src1->type != vec_dot_type) {
+                            cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
                         } else {
                             GGML_ASSERT(false);
                         }
diff --git a/ggml.h b/ggml.h
index 0af96c76b6d0ecd4951c6dc6b8b5e3959bfc3d0b..24ca8ae221c75fe453a58b4c016042921bf6dacd 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -250,8 +250,8 @@ extern "C" {
     GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
     GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
 
-    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
-    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
+    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
+    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
 
     struct ggml_object;
     struct ggml_context;
@@ -1514,26 +1514,19 @@ extern "C" {
     // Internal types and functions exposed for tests and benchmarks
     //
 
-#ifdef  __cplusplus
-    // restrict not standard in C++
-#define GGML_RESTRICT
-#else
-#define GGML_RESTRICT restrict
-#endif
-    typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-    typedef void (*quantize_row_q_t)  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-    typedef void (*vec_dot_q_t)       (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+    typedef void (*ggml_to_float_t)(const void * x, float * y, int k);
+    typedef void (*ggml_from_float_t)(const float * x, void * y, int k);
+    typedef void (*ggml_vec_dot_t)(const int n, float * s, const void * x, const void * y);
 
     typedef struct {
-        dequantize_row_q_t dequantize_row_q;
-        quantize_row_q_t   quantize_row_q;
-        quantize_row_q_t   quantize_row_q_reference;
-        quantize_row_q_t   quantize_row_q_dot;
-        vec_dot_q_t        vec_dot_q;
-        enum ggml_type     vec_dot_type;
-    } quantize_fns_t;
-
-    quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
+        ggml_to_float_t   to_float;
+        ggml_from_float_t from_float;
+        ggml_from_float_t from_float_reference;
+        ggml_vec_dot_t    vec_dot;
+        enum ggml_type    vec_dot_type;
+    } ggml_type_traits_t;
+
+    ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i);
 
 #ifdef  __cplusplus
 }
index e04fbfc0a04de29b0b112e4bed17d2f176fc17f4..7a866cb79106e025887f1cbfe4f5af0510849877 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2257,10 +2257,10 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam
     }
     float * f32_output = (float *) output.addr;
 
-    quantize_fns_t qtype;
+    ggml_type_traits_t qtype;
     if (ggml_is_quantized(tensor.type)) {
-        qtype = ggml_internal_get_quantize_fn(tensor.type);
-        if (qtype.dequantize_row_q == NULL) {
+        qtype = ggml_internal_get_type_traits(tensor.type);
+        if (qtype.to_float == NULL) {
             throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor.type)));
         }
     } else if (tensor.type != GGML_TYPE_F16) {
@@ -2271,7 +2271,7 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam
         if (tensor.type == GGML_TYPE_F16) {
             ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor.data, f32_output, nelements);
         } else if (ggml_is_quantized(tensor.type)) {
-            qtype.dequantize_row_q(tensor.data, f32_output, nelements);
+            qtype.to_float(tensor.data, f32_output, nelements);
         } else {
             LLAMA_ASSERT(false); // unreachable
         }
@@ -2296,7 +2296,7 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam
             if (typ == GGML_TYPE_F16) {
                 ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
             } else {
-                qtype.dequantize_row_q(inbuf, outbuf, nels);
+                qtype.to_float(inbuf, outbuf, nels);
             }
         };
         workers.push_back(std::thread(compute, tensor.type, tensor.data + in_buff_offs, f32_output + out_buff_offs, thr_elems));
index 5748c8ac22193cd2e1ddb90c1fe43291541847ee..4e0e023575322fda13c7d4a359fd3093c0534188 100644 (file)
@@ -136,7 +136,7 @@ int main(int argc, char** argv) {
 
     auto ggml_type = type == 0 ? GGML_TYPE_Q4_0 : GGML_TYPE_Q4_1;
 
-    auto funcs = ggml_internal_get_quantize_fn(ggml_type);
+    auto funcs = ggml_internal_get_type_traits(ggml_type);
 
     Stat simple, ggml;
 
@@ -156,8 +156,8 @@ int main(int argc, char** argv) {
 
         t1 = std::chrono::high_resolution_clock::now();
         float fs;
-        if (type == 0) funcs.vec_dot_q(kVecSize * QK4_1, &fs, x40.data(), y.data());
-        else funcs.vec_dot_q(kVecSize * QK4_1, &fs, x41.data(), y.data());
+        if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, x40.data(), y.data());
+        else funcs.vec_dot(kVecSize * QK4_1, &fs, x41.data(), y.data());
         t2 = std::chrono::high_resolution_clock::now();
         t = 1e-3*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
         if (iloop > 3) ggml.addResult(fs, t);
index 7b18090d66b6b650a28a8ace27845c2dc2bcf4fe..48758cda81fdfd7207cb7fa0fb787e610a60ebed 100644 (file)
@@ -235,7 +235,7 @@ int main(int argc, char** argv) {
     int n4 = useQ4_1 ? kVecSize / QK4_1 : kVecSize / QK4_0; n4 = 64*((n4 + 63)/64);
     int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64);
 
-    auto funcs = useQ4_1 ? ggml_internal_get_quantize_fn(GGML_TYPE_Q4_1) : ggml_internal_get_quantize_fn(GGML_TYPE_Q4_0);
+    auto funcs = useQ4_1 ? ggml_internal_get_type_traits(GGML_TYPE_Q4_1) : ggml_internal_get_type_traits(GGML_TYPE_Q4_0);
 
     std::vector<block_q4_0> q40;
     std::vector<block_q4_1> q41;
@@ -261,9 +261,9 @@ int main(int argc, char** argv) {
         // Note, we do not include this in the timing as in practical application
         // we already have the quantized model weights.
         if (useQ4_1) {
-            funcs.quantize_row_q(x1.data(), q41.data(), kVecSize);
+            funcs.from_float(x1.data(), q41.data(), kVecSize);
         } else {
-            funcs.quantize_row_q(x1.data(), q40.data(), kVecSize);
+            funcs.from_float(x1.data(), q40.data(), kVecSize);
         }
 
         // Now measure time the dot product needs using the "scalar" version above
@@ -282,9 +282,10 @@ int main(int argc, char** argv) {
             dot_q4_q8(kVecSize, &result, q40.data(), q8.data());
         }
         else {
-            funcs.quantize_row_q_dot(y1.data(), q8.data(), kVecSize);
-            if (useQ4_1) funcs.vec_dot_q(kVecSize, &result, q41.data(), q8.data());
-            else funcs.vec_dot_q(kVecSize, &result, q40.data(), q8.data());
+            auto vdot = ggml_internal_get_type_traits(funcs.vec_dot_type);
+            vdot.from_float(y1.data(), q8.data(), kVecSize);
+            if (useQ4_1) funcs.vec_dot(kVecSize, &result, q41.data(), q8.data());
+            else funcs.vec_dot(kVecSize, &result, q40.data(), q8.data());
         }
         sumq += result;
         t2 = std::chrono::high_resolution_clock::now();
index c40f1b29c7c368cb0a3ae350f776882243ff040c..8d3c162d2bfa04bdb29a40131368dd9efda7523d 100644 (file)
@@ -40,26 +40,26 @@ float array_rmse(const float * a1, const float * a2, size_t n) {
 }
 
 // Total quantization error on test data
-float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
+float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
     std::vector<uint8_t> tmp_q(2*test_size);
     std::vector<float> tmp_out(test_size);
 
-    qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
-    qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
+    qfns.from_float(test_data, tmp_q.data(), test_size);
+    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
     return array_rmse(test_data, tmp_out.data(), test_size);
 }
 
 // Total quantization error on test data
-float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
+float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
     std::vector<uint8_t> tmp_q(2*test_size);
     std::vector<float> tmp_out(test_size);
     std::vector<float> tmp_out_ref(test_size);
 
-    qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
-    qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
+    qfns.from_float(test_data, tmp_q.data(), test_size);
+    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
 
-    qfns.quantize_row_q_reference(test_data, tmp_q.data(), test_size);
-    qfns.dequantize_row_q(tmp_q.data(), tmp_out_ref.data(), test_size);
+    qfns.from_float_reference(test_data, tmp_q.data(), test_size);
+    qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
 
     return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
 }
@@ -73,15 +73,17 @@ float dot_product(const float * a1, const float * a2, size_t test_size) {
 }
 
 // Total dot product error
-float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
+float dot_product_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
     std::vector<uint8_t> tmp_q1(2*test_size);
     std::vector<uint8_t> tmp_q2(2*test_size);
 
-    qfns.quantize_row_q    (test_data1, tmp_q1.data(), test_size);
-    qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size);
+    auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
+
+    qfns.from_float(test_data1, tmp_q1.data(), test_size);
+    vdot.from_float(test_data2, tmp_q2.data(), test_size);
 
     float result = INFINITY;
-    qfns.vec_dot_q(test_size, &result, tmp_q1.data(), tmp_q2.data());
+    qfns.vec_dot(test_size, &result, tmp_q1.data(), tmp_q2.data());
 
     const float dot_ref = dot_product(test_data1, test_data2, test_size);
 
@@ -123,9 +125,9 @@ int main(int argc, char * argv[]) {
 
     for (int i = 0; i < GGML_TYPE_COUNT; i++) {
         ggml_type type = (ggml_type) i;
-        quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
+        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
 
-        if (qfns.quantize_row_q && qfns.dequantize_row_q) {
+        if (qfns.from_float && qfns.to_float) {
             const float total_error = total_quantization_error(qfns, test_size, test_data.data());
             const float max_quantization_error =
                 type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
index c0e361e92313f3da5a2ce808e255089e85d867ca..0bb9537f693ed22ba58bb642d7f124cd2373e8c8 100644 (file)
@@ -123,9 +123,9 @@ void usage(char * argv[]) {
     printf("  --type TYPE           set test type as");
     for (int i = 0; i < GGML_TYPE_COUNT; i++) {
         ggml_type type = (ggml_type) i;
-        quantize_fns_t qfns = ggml_internal_get_quantize_fn(type);
+        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
         if (ggml_type_name(type) != NULL) {
-            if (qfns.quantize_row_q && qfns.dequantize_row_q) {
+            if (qfns.from_float && qfns.to_float) {
                 printf(" %s", ggml_type_name(type));
             }
         }
@@ -271,12 +271,12 @@ int main(int argc, char * argv[]) {
 
     for (int i = 0; i < GGML_TYPE_COUNT; i++) {
         ggml_type type = (ggml_type) i;
-        quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
+        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
         if (!params.include_types.empty() && ggml_type_name(type) && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {
             continue;
         }
 
-        if (qfns.quantize_row_q && qfns.dequantize_row_q) {
+        if (qfns.from_float && qfns.to_float) {
             printf("%s\n", ggml_type_name(type));
 
             if (params.op_quantize_row_q_reference) {
@@ -284,7 +284,7 @@ int main(int argc, char * argv[]) {
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void ) {
-                        qfns.quantize_row_q_reference(test_data1, test_q1, size);
+                        qfns.from_float_reference(test_data1, test_q1, size);
                         return test_q1[0];
                     };
                     size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -298,7 +298,7 @@ int main(int argc, char * argv[]) {
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void ) {
-                        qfns.quantize_row_q(test_data1, test_q1, size);
+                        qfns.from_float(test_data1, test_q1, size);
                         return test_q1[0];
                     };
                     size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -309,11 +309,11 @@ int main(int argc, char * argv[]) {
 
             if (params.op_dequantize_row_q) {
                 printf("  dequantize_row_q\n");
-                qfns.quantize_row_q(test_data1, test_q1, largest);
+                qfns.from_float(test_data1, test_q1, largest);
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void ) {
-                        qfns.dequantize_row_q(test_q1, test_out, size);
+                        qfns.to_float(test_q1, test_out, size);
                         return test_out[0];
                     };
                     size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -327,7 +327,8 @@ int main(int argc, char * argv[]) {
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void ) {
-                        qfns.quantize_row_q_dot(test_data1, test_q1, size);
+                        auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
+                        vdot.from_float(test_data1, test_q1, size);
                         return test_q1[0];
                     };
                     size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);
@@ -338,13 +339,13 @@ int main(int argc, char * argv[]) {
 
             if (params.op_vec_dot_q) {
                 printf("  vec_dot_q\n");
-                qfns.quantize_row_q(test_data1, test_q1, largest);
-                qfns.quantize_row_q(test_data2, test_q2, largest);
+                qfns.from_float(test_data1, test_q1, largest);
+                qfns.from_float(test_data2, test_q2, largest);
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void ) {
                         float result;
-                        qfns.vec_dot_q(size, &result, test_q1, test_q2);
+                        qfns.vec_dot(size, &result, test_q1, test_q2);
                         return result;
                     };
                     size_t quantized_size = size / ggml_blck_size(type) * ggml_type_size(type);