]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : fix BLAS with unsupported types (llama/9775)
authorDiego Devesa <redacted>
Tue, 8 Oct 2024 12:21:43 +0000 (14:21 +0200)
committerGeorgi Gerganov <redacted>
Fri, 1 Nov 2024 08:19:05 +0000 (10:19 +0200)
* ggml : do not use BLAS with types without to_float

* ggml : return pointer from ggml_internal_get_type_traits to avoid unnecessary copies

* ggml : rename ggml_internal_get_type_traits -> ggml_get_type_traits

it's not really internal if everybody uses it

ggml/include/ggml.h
ggml/src/ggml-backend.cpp
ggml/src/ggml-blas.cpp
ggml/src/ggml-vulkan.cpp
ggml/src/ggml.c

index 8d36b3d4d42e38aaed193e74636d4c38c9491897..14f4eb9bd128917522c911c6ed2a971b427d1d2e 100644 (file)
@@ -2536,7 +2536,7 @@ extern "C" {
     typedef void (*ggml_gemm_t)     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
                                        const void * GGML_RESTRICT y, int nr, int nc);
 
-    typedef struct {
+    struct ggml_type_traits {
         const char             * type_name;
         int64_t                  blck_size;
         int64_t                  blck_size_interleave; // interleave elements in blocks
@@ -2552,9 +2552,9 @@ extern "C" {
         int64_t                  ncols; // number of columns to process simultaneously
         ggml_gemv_t              gemv;
         ggml_gemm_t              gemm;
-    } ggml_type_traits_t;
+    };
 
-    GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
+    GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type);
 
 #ifdef  __cplusplus
 }
index fbd49d13dda14247c29a87ae366a4411c4a1067f..627b4dbc7873213b2c2a923d6572ded397afdc6a 100644 (file)
@@ -1177,7 +1177,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_S   &&
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
         case GGML_OP_ROPE_BACK:
             return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
         case GGML_OP_IM2COL_BACK:
index 0c6574de500a69e0d9252c3a15e7368f3dd28c2d..55f7245861105eada9dfa7777042159edb32a908 100644 (file)
@@ -65,8 +65,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
 
     // convert src0 to float
     if (type != GGML_TYPE_F32) {
-        ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
-        ggml_to_float_t const to_float = type_traits.to_float;
+        const auto * type_traits = ggml_get_type_traits(type);
+        ggml_to_float_t const to_float = type_traits->to_float;
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -420,19 +420,21 @@ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const s
             // TODO: find the optimal value
             const int64_t min_batch = 32;
 
-            return (ggml_is_contiguous(src0) &&
-                    ggml_is_contiguous(src1) &&
-                    src1->type == GGML_TYPE_F32 &&
-                    (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch));
+            return ggml_is_contiguous(src0) &&
+                   ggml_is_contiguous(src1) &&
+                   src1->type == GGML_TYPE_F32 &&
+                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
         }
 
         case GGML_OP_OUT_PROD:
-            return (op->src[0]->type == GGML_TYPE_F32 &&
-                    op->src[1]->type == GGML_TYPE_F32 &&
-                    ggml_is_matrix(src0) &&
-                    ggml_is_matrix(src1) &&
-                    ggml_is_contiguous(src0) &&
-                    (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
+            return op->src[0]->type == GGML_TYPE_F32 &&
+                   op->src[1]->type == GGML_TYPE_F32 &&
+                   ggml_is_matrix(src0) &&
+                   ggml_is_matrix(src1) &&
+                   ggml_is_contiguous(src0) &&
+                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
 
         default:
             return false;
index 30bd376da61882aca8361f98a4333ae4933c3961..374c6ecd7ade5b0789ab1bdd0011603a06ee0618 100644 (file)
@@ -5287,9 +5287,9 @@ static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, gg
         return;
     }
 
-    ggml_type_traits_t tt = ggml_internal_get_type_traits(quant);
+    const auto * tt = ggml_get_type_traits(quant);
 
-    ggml_to_float_t dequant_fn = tt.to_float;
+    ggml_to_float_t dequant_fn = tt->to_float;
 
     dequant_fn(from, to, ne);
 }
index 264ffb5195ebf0f7d525bb0ef5a58d2ccfc87c37..439978eac29cdba26931f58b2c62114ad1de00df 100644 (file)
@@ -730,7 +730,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float *
 static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
 static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
 
-static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
         .type_name                = "i8",
         .blck_size                = 1,
@@ -1152,9 +1152,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
 };
 
 // For internal test use
-ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
+const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
     GGML_ASSERT(type < GGML_TYPE_COUNT);
-    return type_traits[type];
+    return &type_traits[type];
 }
 
 //