// Function-pointer type of the `gemm` entry stored in the type-traits record below.
// NOTE(review): the name suggests a matrix-matrix multiply kernel writing into `s`;
// the exact meaning of n, bs, nr and nc is not visible here - confirm against the implementations.
typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
const void * GGML_RESTRICT y, int nr, int nc);
- typedef struct {
+ struct ggml_type_traits {
const char * type_name;
int64_t blck_size;
int64_t blck_size_interleave; // interleave elements in blocks
int64_t ncols; // number of columns to process simultaneously
ggml_gemv_t gemv;
ggml_gemm_t gemm;
- } ggml_type_traits_t;
+ };
// Accessor for the traits record of `type` (exported via GGML_API).
- GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
+ GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type);
#ifdef __cplusplus
}
op->type != GGML_TYPE_IQ1_S &&
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
case GGML_OP_MUL_MAT:
- return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+ return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
case GGML_OP_ROPE_BACK:
return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
case GGML_OP_IM2COL_BACK:
// convert src0 to float
if (type != GGML_TYPE_F32) {
- ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
- ggml_to_float_t const to_float = type_traits.to_float;
+ const auto * type_traits = ggml_get_type_traits(type);
+ ggml_to_float_t const to_float = type_traits->to_float;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
// TODO: find the optimal value
const int64_t min_batch = 32;
- return (ggml_is_contiguous(src0) &&
- ggml_is_contiguous(src1) &&
- src1->type == GGML_TYPE_F32 &&
- (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch));
+ return ggml_is_contiguous(src0) &&
+ ggml_is_contiguous(src1) &&
+ src1->type == GGML_TYPE_F32 &&
+ (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
}
case GGML_OP_OUT_PROD:
- return (op->src[0]->type == GGML_TYPE_F32 &&
- op->src[1]->type == GGML_TYPE_F32 &&
- ggml_is_matrix(src0) &&
- ggml_is_matrix(src1) &&
- ggml_is_contiguous(src0) &&
- (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+ return op->src[0]->type == GGML_TYPE_F32 &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ ggml_is_matrix(src0) &&
+ ggml_is_matrix(src1) &&
+ ggml_is_contiguous(src0) &&
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
default:
return false;
return;
}
- ggml_type_traits_t tt = ggml_internal_get_type_traits(quant);
+ const auto * tt = ggml_get_type_traits(quant);
- ggml_to_float_t dequant_fn = tt.to_float;
+ ggml_to_float_t dequant_fn = tt->to_float;
dequant_fn(from, to, ne);
}
// Forward declarations of the file-local ggml_vec_dot_f16 / ggml_vec_dot_bf16 routines;
// their definitions are not part of this excerpt.
// NOTE(review): presumably declared early so the type_traits table can reference them - confirm.
static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
-static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = {
.type_name = "i8",
.blck_size = 1,
};
// For internal test use
// Looks up the entry for `type` in the type_traits table.
// Aborts via GGML_ASSERT when `type` is not below GGML_TYPE_COUNT, so the index is always in range.
// NOTE(review): other hunks call this accessor from op-support checks, not only from tests -
// the "internal test use" remark may be stale; confirm before relying on it.
-ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
+const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
GGML_ASSERT(type < GGML_TYPE_COUNT);
- return type_traits[type];
+ return &type_traits[type];
}
//