]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : move all type info to ggml_type_traits (#2663)
authorslaren <redacted>
Sun, 20 Aug 2023 20:17:53 +0000 (22:17 +0200)
committerGitHub <redacted>
Sun, 20 Aug 2023 20:17:53 +0000 (22:17 +0200)
ggml.c
ggml.h

diff --git a/ggml.c b/ggml.c
index beb7f464167d53b03a98e25f49ff97c666599e41..44c43b42409a9e4c4e236dee7677c8d8496a007b 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1643,11 +1643,37 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
 static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_I8] = {
+        .type_name                = "i8",
+        .blck_size                = 1,
+        .type_size                = sizeof(int8_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I16] = {
+        .type_name                = "i16",
+        .blck_size                = 1,
+        .type_size                = sizeof(int16_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I32] = {
+        .type_name                = "i32",
+        .blck_size                = 1,
+        .type_size                = sizeof(int32_t),
+        .is_quantized             = false,
+    },
     [GGML_TYPE_F32] = {
+        .type_name                = "f32",
+        .blck_size                = 1,
+        .type_size                = sizeof(float),
+        .is_quantized             = false,
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type             = GGML_TYPE_F32,
     },
     [GGML_TYPE_F16] = {
+        .type_name                = "f16",
+        .blck_size                = 1,
+        .type_size                = sizeof(ggml_fp16_t),
+        .is_quantized             = false,
         .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,
         .from_float               = (ggml_from_float_t) ggml_fp32_to_fp16_row,
         .from_float_reference     = (ggml_from_float_t) ggml_fp32_to_fp16_row,
@@ -1655,6 +1681,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_F16,
     },
     [GGML_TYPE_Q4_0] = {
+        .type_name                = "q4_0",
+        .blck_size                = QK4_0,
+        .type_size                = sizeof(block_q4_0),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,
         .from_float               = quantize_row_q4_0,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q4_0_reference,
@@ -1662,6 +1692,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q4_1] = {
+        .type_name                = "q4_1",
+        .blck_size                = QK4_1,
+        .type_size                = sizeof(block_q4_1),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
         .from_float               = quantize_row_q4_1,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q4_1_reference,
@@ -1669,6 +1703,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q5_0] = {
+        .type_name                = "q5_0",
+        .blck_size                = QK5_0,
+        .type_size                = sizeof(block_q5_0),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,
         .from_float               = quantize_row_q5_0,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q5_0_reference,
@@ -1676,6 +1714,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q5_1] = {
+        .type_name                = "q5_1",
+        .blck_size                = QK5_1,
+        .type_size                = sizeof(block_q5_1),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,
         .from_float               = quantize_row_q5_1,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q5_1_reference,
@@ -1683,6 +1725,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
     [GGML_TYPE_Q8_0] = {
+        .type_name                = "q8_0",
+        .blck_size                = QK8_0,
+        .type_size                = sizeof(block_q8_0),
+        .is_quantized             = true,
         .to_float                 = dequantize_row_q8_0,
         .from_float               = quantize_row_q8_0,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q8_0_reference,
@@ -1690,12 +1736,20 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q8_1] = {
+        .type_name                = "q8_1",
+        .blck_size                = QK8_1,
+        .type_size                = sizeof(block_q8_1),
+        .is_quantized             = true,
         .from_float               = quantize_row_q8_1,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q8_1_reference,
         .vec_dot_type             = GGML_TYPE_Q8_1,
     },
 #ifdef GGML_USE_K_QUANTS
     [GGML_TYPE_Q2_K] = {
+        .type_name                = "q2_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q2_K),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,
         .from_float               = quantize_row_q2_K,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q2_K_reference,
@@ -1703,6 +1757,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q3_K] = {
+        .type_name                = "q3_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q3_K),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,
         .from_float               = quantize_row_q3_K,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q3_K_reference,
@@ -1710,6 +1768,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q4_K] = {
+        .type_name                = "q4_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q4_K),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,
         .from_float               = quantize_row_q4_K,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q4_K_reference,
@@ -1717,6 +1779,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q5_K] = {
+        .type_name                = "q5_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q5_K),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,
         .from_float               = quantize_row_q5_K,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q5_K_reference,
@@ -1724,6 +1790,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q6_K] = {
+        .type_name                = "q6_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q6_K),
+        .is_quantized             = true,
         .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,
         .from_float               = quantize_row_q6_K,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q6_K_reference,
@@ -1731,15 +1801,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
     },
     [GGML_TYPE_Q8_K] = {
+        .type_name                = "q8_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q8_K),
+        .is_quantized             = true,
         .from_float               = quantize_row_q8_K,
     }
 #endif
 };
 
 // For internal test use
-ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return type_traits[i];
+ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
+    GGML_ASSERT(type < GGML_TYPE_COUNT);
+    return type_traits[type];
 }
 
 
@@ -3648,99 +3722,6 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
     *s = idx;
 }
 
-//
-// data types
-//
-
-static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_F32]  = 1,
-    [GGML_TYPE_F16]  = 1,
-    [GGML_TYPE_Q4_0] = QK4_0,
-    [GGML_TYPE_Q4_1] = QK4_1,
-    [GGML_TYPE_Q5_0] = QK5_0,
-    [GGML_TYPE_Q5_1] = QK5_1,
-    [GGML_TYPE_Q8_0] = QK8_0,
-    [GGML_TYPE_Q8_1] = QK8_1,
-#ifdef GGML_USE_K_QUANTS
-    [GGML_TYPE_Q2_K] = QK_K,
-    [GGML_TYPE_Q3_K] = QK_K,
-    [GGML_TYPE_Q4_K] = QK_K,
-    [GGML_TYPE_Q5_K] = QK_K,
-    [GGML_TYPE_Q6_K] = QK_K,
-    [GGML_TYPE_Q8_K] = QK_K,
-#endif
-    [GGML_TYPE_I8]   = 1,
-    [GGML_TYPE_I16]  = 1,
-    [GGML_TYPE_I32]  = 1,
-};
-static_assert(GGML_TYPE_COUNT == 19, "GGML_BLCK_SIZE is outdated");
-
-static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_F32]  = sizeof(float),
-    [GGML_TYPE_F16]  = sizeof(ggml_fp16_t),
-    [GGML_TYPE_Q4_0] = sizeof(block_q4_0),
-    [GGML_TYPE_Q4_1] = sizeof(block_q4_1),
-    [GGML_TYPE_Q5_0] = sizeof(block_q5_0),
-    [GGML_TYPE_Q5_1] = sizeof(block_q5_1),
-    [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
-    [GGML_TYPE_Q8_1] = sizeof(block_q8_1),
-#ifdef GGML_USE_K_QUANTS
-    [GGML_TYPE_Q2_K] = sizeof(block_q2_K),
-    [GGML_TYPE_Q3_K] = sizeof(block_q3_K),
-    [GGML_TYPE_Q4_K] = sizeof(block_q4_K),
-    [GGML_TYPE_Q5_K] = sizeof(block_q5_K),
-    [GGML_TYPE_Q6_K] = sizeof(block_q6_K),
-    [GGML_TYPE_Q8_K] = sizeof(block_q8_K),
-#endif
-    [GGML_TYPE_I8]   = sizeof(int8_t),
-    [GGML_TYPE_I16]  = sizeof(int16_t),
-    [GGML_TYPE_I32]  = sizeof(int32_t),
-};
-static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_SIZE is outdated");
-
-
-static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_F32]  = "f32",
-    [GGML_TYPE_F16]  = "f16",
-    [GGML_TYPE_Q4_0] = "q4_0",
-    [GGML_TYPE_Q4_1] = "q4_1",
-    [GGML_TYPE_Q5_0] = "q5_0",
-    [GGML_TYPE_Q5_1] = "q5_1",
-    [GGML_TYPE_Q8_0] = "q8_0",
-    [GGML_TYPE_Q8_1] = "q8_1",
-    [GGML_TYPE_Q2_K] = "q2_K",
-    [GGML_TYPE_Q3_K] = "q3_K",
-    [GGML_TYPE_Q4_K] = "q4_K",
-    [GGML_TYPE_Q5_K] = "q5_K",
-    [GGML_TYPE_Q6_K] = "q6_K",
-    [GGML_TYPE_Q8_K] = "q8_K",
-    [GGML_TYPE_I8]   = "i8",
-    [GGML_TYPE_I16]  = "i16",
-    [GGML_TYPE_I32]  = "i32",
-};
-static_assert(GGML_TYPE_COUNT == 19, "GGML_TYPE_NAME is outdated");
-
-static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_F32]  = false,
-    [GGML_TYPE_F16]  = false,
-    [GGML_TYPE_Q4_0] = true,
-    [GGML_TYPE_Q4_1] = true,
-    [GGML_TYPE_Q5_0] = true,
-    [GGML_TYPE_Q5_1] = true,
-    [GGML_TYPE_Q8_0] = true,
-    [GGML_TYPE_Q8_1] = true,
-    [GGML_TYPE_Q2_K] = true,
-    [GGML_TYPE_Q3_K] = true,
-    [GGML_TYPE_Q4_K] = true,
-    [GGML_TYPE_Q5_K] = true,
-    [GGML_TYPE_Q6_K] = true,
-    [GGML_TYPE_Q8_K] = true,
-    [GGML_TYPE_I8]   = false,
-    [GGML_TYPE_I16]  = false,
-    [GGML_TYPE_I32]  = false,
-};
-static_assert(GGML_TYPE_COUNT == 19, "GGML_IS_QUANTIZED is outdated");
-
 static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "NONE",
 
@@ -4110,29 +4091,33 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     //
     // is enough, but just in case, adding the second part
 
-    return GGML_PAD(MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]), GGML_MEM_ALIGN);
+    return GGML_PAD(MAX(tensor->ne[3]*tensor->nb[3], ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type), GGML_MEM_ALIGN);
 }
 
 size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+    return (nrows_split*tensor->ne[0]*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type);
 }
 
 int ggml_blck_size(enum ggml_type type) {
-    return GGML_BLCK_SIZE[type];
+    return type_traits[type].blck_size;
 }
 
 size_t ggml_type_size(enum ggml_type type) {
-    return GGML_TYPE_SIZE[type];
+    return type_traits[type].type_size;
 }
 
 float ggml_type_sizef(enum ggml_type type) {
-    return ((float)(GGML_TYPE_SIZE[type]))/GGML_BLCK_SIZE[type];
+    return ((float)(type_traits[type].type_size))/type_traits[type].blck_size;
 }
 
 const char * ggml_type_name(enum ggml_type type) {
-    return GGML_TYPE_NAME[type];
+    return type_traits[type].type_name;
+}
+
+bool ggml_is_quantized(enum ggml_type type) {
+    return type_traits[type].is_quantized;
 }
 
 const char * ggml_op_name(enum ggml_op op) {
@@ -4144,7 +4129,7 @@ const char * ggml_op_symbol(enum ggml_op op) {
 }
 
 size_t ggml_element_size(const struct ggml_tensor * tensor) {
-    return GGML_TYPE_SIZE[tensor->type];
+    return ggml_type_size(tensor->type);
 }
 
 static inline bool ggml_is_scalar(const struct ggml_tensor * tensor) {
@@ -4182,10 +4167,6 @@ static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct
         (t0->ne[3] == t1->ne[3]);
 }
 
-bool ggml_is_quantized(enum ggml_type type) {
-    return GGML_IS_QUANTIZED[type];
-}
-
 enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
     enum ggml_type wtype = GGML_TYPE_COUNT;
 
@@ -4223,8 +4204,8 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
-        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/GGML_BLCK_SIZE[tensor->type] &&
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
         tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
@@ -4233,7 +4214,7 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
         tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
@@ -4248,7 +4229,7 @@ static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
         tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
@@ -4567,7 +4548,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     size_t data_size = 0;
 
     if (data == NULL && !ctx->no_alloc) {
-        data_size += GGML_TYPE_SIZE[type]*(ne[0]/GGML_BLCK_SIZE[type]);
+        data_size += ggml_type_size(type)*(ne[0]/ggml_blck_size(type));
         for (int i = 1; i < n_dims; i++) {
             data_size *= ne[i];
         }
@@ -4622,8 +4603,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         result->ne[i] = ne[i];
     }
 
-    result->nb[0] = GGML_TYPE_SIZE[type];
-    result->nb[1] = result->nb[0]*(result->ne[0]/GGML_BLCK_SIZE[type]);
+    result->nb[0] = ggml_type_size(type);
+    result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
     for (int i = 2; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
     }
@@ -7745,7 +7726,7 @@ static void ggml_compute_forward_dup_same_cont(
         memcpy(
             ((char *)  dst->data + ie0*nb0),
             ((char *) src0->data + ie0*nb00),
-            (ie1 - ie0) * GGML_TYPE_SIZE[src0->type]);
+            (ie1 - ie0) * ggml_type_size(src0->type));
     }
 
 }
@@ -7779,7 +7760,7 @@ static void ggml_compute_forward_dup_f16(
 
     if (src0->type == dst->type &&
         ne00 == ne0 &&
-        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
         // copy by rows
         const size_t rs = ne00*nb00;
         for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7837,7 +7818,7 @@ static void ggml_compute_forward_dup_f16(
                 float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 
                 size_t id = 0;
-                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
                 char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
@@ -8050,7 +8031,7 @@ static void ggml_compute_forward_dup_f32(
 
     if (src0->type == dst->type &&
         ne00 == ne0 &&
-        nb00 == GGML_TYPE_SIZE[src0->type] && nb0 == GGML_TYPE_SIZE[dst->type]) {
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
         // copy by rows
         const size_t rs = ne00*nb00;
         for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -8089,7 +8070,7 @@ static void ggml_compute_forward_dup_f32(
                 ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
 
                 size_t id = 0;
-                size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
                 char * dst_ptr = (char *) dst->data;
 
                 for (int i03 = 0; i03 < ne03; i03++) {
@@ -8501,7 +8482,7 @@ static void ggml_compute_forward_add_q_f32(
     ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
 
     // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
@@ -8775,7 +8756,7 @@ static void ggml_compute_forward_add1_q_f32(
     ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
 
     // we don't support permuted src0
-    GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb00 == ggml_type_size(type));
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 <= nb1);
@@ -10629,7 +10610,7 @@ static void ggml_compute_forward_mul_mat(
     GGML_ASSERT(ne3 == ne13);
 
     // we don't support permuted src0 or src1
-    GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == sizeof(float));
 
     // dst cannot be transposed or permuted
@@ -10712,7 +10693,7 @@ static void ggml_compute_forward_mul_mat(
     if (params->type == GGML_TASK_INIT) {
         if (src1->type != vec_dot_type) {
             char * wdata = params->wdata;
-            const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+            const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -10732,7 +10713,7 @@ static void ggml_compute_forward_mul_mat(
     }
 
     const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+    const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
 
     const int64_t nr0 = ne01;           // src0 rows
     const int64_t nr1 = ne11*ne12*ne13; // src1 rows
@@ -11205,7 +11186,7 @@ static void ggml_compute_forward_get_rows_q(
 
     assert( dst->ne[0] == nc);
     assert( dst->ne[1] == nr);
-    assert(src0->nb[0] == GGML_TYPE_SIZE[type]);
+    assert(src0->nb[0] == ggml_type_size(type));
 
     for (int i = 0; i < nr; ++i) {
         const int r = ((int32_t *) src1->data)[i];
@@ -16382,7 +16363,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
 
                     size_t cur = 0;
                     if (ggml_is_quantized(node->type)) {
-                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     }
 
                     work_size = MAX(work_size, cur);
@@ -16395,7 +16376,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     size_t cur = 0;
 
                     if (ggml_is_quantized(node->src[0]->type)) {
-                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
 
                     work_size = MAX(work_size, cur);
@@ -16407,7 +16388,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     size_t cur = 0;
 
                     if (ggml_is_quantized(node->src[0]->type)) {
-                        cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
 
                     work_size = MAX(work_size, cur);
@@ -16490,12 +16471,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                                      //       the threads are still spinning
                         if (node->src[0]->type != GGML_TYPE_F32) {
                             // here we need memory just for single 2D matrix from src0
-                            cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
+                            cur = ggml_type_size(GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]);
                         }
                     } else
 #endif
                     if (node->src[1]->type != vec_dot_type) {
-                        cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
+                        cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type);
                     } else {
                         cur = 0;
                     }
@@ -18301,8 +18282,8 @@ enum ggml_opt_result ggml_opt_resume(
         struct ggml_tensor * f) {
 
     // build forward + backward compute graphs
-    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
-    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / GGML_TYPE_SIZE[GGML_TYPE_I32]+ (sizeof(struct ggml_cgraph) % GGML_TYPE_SIZE[GGML_TYPE_I32] ? 1 : 0));
+    struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
+    struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32)+ (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
 
     struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
     struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
diff --git a/ggml.h b/ggml.h
index bdbd12800433242dec1e6dac8c96875540dd19e7..3a946dbdc44d72497c89cfe8cb9f878a33ddcc60 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1740,6 +1740,10 @@ extern "C" {
     typedef void (*ggml_vec_dot_t)   (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
 
     typedef struct {
+        const char      * type_name;
+        int               blck_size;
+        size_t            type_size;
+        bool              is_quantized;
         ggml_to_float_t   to_float;
         ggml_from_float_t from_float;
         ggml_from_float_t from_float_reference;
@@ -1747,7 +1751,7 @@ extern "C" {
         enum ggml_type    vec_dot_type;
     } ggml_type_traits_t;
 
-    ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i);
+    ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
 
 #ifdef  __cplusplus
 }