]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add F32 support + update bench output
authorGeorgi Gerganov <redacted>
Fri, 15 Sep 2023 10:56:08 +0000 (13:56 +0300)
committerGeorgi Gerganov <redacted>
Fri, 15 Sep 2023 10:56:08 +0000 (13:56 +0300)
Makefile
extra/bench-all.sh
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
whisper.cpp

index 2df511167bb42e12bbe94fd6b8421e786ecb8b4f..e35222579aef06db67b36288681d1e538a880898 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -186,6 +186,7 @@ ifndef WHISPER_NO_METAL
        ifeq ($(UNAME_S),Darwin)
                WHISPER_METAL := 1
 
+               CFLAGS   += -DGGML_USE_METAL
                CXXFLAGS += -DGGML_USE_METAL
                LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
        endif
index 352a2235abb6c1d7572f234f96b767b7bea0570d..8fd18b7d16a35064a386149f1379ea3fdf4c6c59 100755 (executable)
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi
 
-printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
-printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
+printf "| %6s | %6s | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
 
 for model in "${models[@]}"; do
     # actual run
@@ -83,9 +83,13 @@ for model in "${models[@]}"; do
         config="$config COREML"
     fi
 
+    if [[ $system_info == *"METAL = 1"* ]]; then
+        config="$config METAL"
+    fi
+
     commit=$(git rev-parse --short HEAD)
 
     if [ $ret -eq 0 ]; then
-        printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
+        printf "| <todo> | <todo> | %16s | %11s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
     fi
 done
index b438b83f9ffa1ec0475e9354f9901353ab4fd8dd..c5b6b8b9aae81b34fb5c6d14358693d195e6363f 100644 (file)
@@ -78,6 +78,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
+    GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
@@ -89,6 +90,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@@ -237,6 +239,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
+        GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
@@ -248,6 +251,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@@ -309,6 +313,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
+    GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
@@ -320,6 +325,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@@ -885,6 +891,7 @@ void ggml_metal_graph_compute(
                                 ne00%32 == 0 &&
                                 ne11 > 1) {
                                 switch (src0->type) {
+                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
@@ -919,6 +926,11 @@ void ggml_metal_graph_compute(
 
                                 // use custom matrix x vector kernel
                                 switch (src0t) {
+                                    case GGML_TYPE_F32:
+                                        {
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+                                            nrows = 4;
+                                        } break;
                                     case GGML_TYPE_F16:
                                         {
                                             nth0 = 32;
index 0db037c1636b33e8823b3dc0495822306d113534..3087ecda812d90d510ffc6859ff3d4dc52f12e7e 100644 (file)
@@ -523,6 +523,79 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
+#define N_F32_F32 4
+
+kernel void kernel_mul_mat_f32_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t im = tgpig.z;
+
+    device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const float4 * x4 = (device const float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
 kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
@@ -1399,13 +1472,13 @@ kernel void kernel_mul_mat_q4_K_f32(
         device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
-        constant   int64_t & ne01[[buffer(4)]],
-        constant   int64_t & ne02[[buffer(5)]],
-        constant   int64_t & ne10[[buffer(9)]],
-        constant   int64_t & ne12[[buffer(11)]],
-        constant   int64_t & ne0[[buffer(15)]],
-        constant   int64_t & ne1[[buffer(16)]],
-        constant   uint    & gqa[[buffer(17)]],
+        constant   int64_t & ne01 [[buffer(4)]],
+        constant   int64_t & ne02 [[buffer(5)]],
+        constant   int64_t & ne10 [[buffer(9)]],
+        constant   int64_t & ne12 [[buffer(11)]],
+        constant   int64_t & ne0  [[buffer(15)]],
+        constant   int64_t & ne1  [[buffer(16)]],
+        constant   uint    & gqa  [[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2268,6 +2341,7 @@ typedef void (mat_mm_t)(
         constant       uint & gqa,
         threadgroup uchar *, uint3, uint, uint);
 
+template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
diff --git a/ggml.c b/ggml.c
index c5b5dd65bb02218ccbaad190fdb279323b86c31d..b1d4678c011162f940c431b99bc127b61135aba1 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -20753,6 +20753,14 @@ int ggml_cpu_has_arm_fma(void) {
 #endif
 }
 
+int ggml_cpu_has_metal(void) {
+#if defined(GGML_USE_METAL)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_f16c(void) {
 #if defined(__F16C__)
     return 1;
diff --git a/ggml.h b/ggml.h
index c936823d661404434484bebf1c88ed0d2c822e48..62d19f387ffc5c76567631f26a8b715272f00601 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1961,6 +1961,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
     GGML_API int ggml_cpu_has_arm_fma    (void);
+    GGML_API int ggml_cpu_has_metal      (void);
     GGML_API int ggml_cpu_has_f16c       (void);
     GGML_API int ggml_cpu_has_fp16_va    (void);
     GGML_API int ggml_cpu_has_wasm_simd  (void);
index 23ebd7e95c54c1e6d0e2daa613b38c3e28358baf..1224be9bf51bcb4ae10dc5adfef9015aca807da0 100644 (file)
@@ -3669,6 +3669,7 @@ const char * whisper_print_system_info(void) {
     s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
     s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
     s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
+    s += "METAL = "     + std::to_string(ggml_cpu_has_metal())     + " | ";
     s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
     s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
     s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";