]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : opt-in compile flag for BF16 (llama/10218)
authorGeorgi Gerganov <redacted>
Fri, 8 Nov 2024 19:59:46 +0000 (21:59 +0200)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
* metal : opt-in compile flag for BF16

ggml-ci

* ci : use BF16

ggml-ci

* swift : switch back to v12

* metal : has_float -> use_float

ggml-ci

* metal : fix BF16 check in MSL

ggml-ci

ggml/CMakeLists.txt
ggml/src/CMakeLists.txt
ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal

index 6866a25d3d4455568822530d1a99635f0d423dfe..81b7a02f5192f999fcb4b20fa7a6ed259aa0eb27 100644 (file)
@@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation"
 option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF)
 option(GGML_KOMPUTE                         "ggml: use Kompute"                               OFF)
 option(GGML_METAL                           "ggml: use Metal"                                 ${GGML_METAL_DEFAULT})
+option(GGML_METAL_USE_BF16                  "ggml: use bfloat if available"                   OFF)
 option(GGML_METAL_NDEBUG                    "ggml: disable Metal debugging"                   OFF)
 option(GGML_METAL_SHADER_DEBUG              "ggml: compile Metal with -fno-fast-math"         OFF)
 option(GGML_METAL_EMBED_LIBRARY             "ggml: embed Metal library"                       ${GGML_METAL})
index 34b81bd7fdda1ea7b3cab69454b65f474d472b01..6c5b816d2f5e7d18fb4c972a7fbe5f3ef70aa925 100644 (file)
@@ -58,6 +58,10 @@ if (GGML_METAL)
         add_compile_definitions(GGML_METAL_NDEBUG)
     endif()
 
+    if (GGML_METAL_USE_BF16)
+        add_compile_definitions(GGML_METAL_USE_BF16)
+    endif()
+
     # copy ggml-common.h and ggml-metal.metal to bin directory
     configure_file(ggml-common.h    ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
     configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
index e19397fd2de70702c4a8807ccd4c7a36ec3aaa94..10d59cb9f7aa83985f13b6a2cd6bb9bcff1a5447 100644 (file)
@@ -39,6 +39,7 @@ static struct ggml_backend_metal_device_context {
     bool has_simdgroup_reduction;
     bool has_simdgroup_mm;
     bool has_bfloat;
+    bool use_bfloat;
 
     char name[128];
 } g_ggml_ctx_dev_main = {
@@ -47,6 +48,7 @@ static struct ggml_backend_metal_device_context {
     /*.has_simdgroup_reduction =*/ false,
     /*.has_simdgroup_mm        =*/ false,
     /*.has_bfloat              =*/ false,
+    /*.use_bfloat              =*/ false,
     /*.name                    =*/ "",
 };
 
@@ -65,6 +67,12 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
         ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
 
+#if defined(GGML_METAL_USE_BF16)
+        ctx->use_bfloat = ctx->has_bfloat;
+#else
+        ctx->use_bfloat = false;
+#endif
+
         strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
     }
 
@@ -504,6 +512,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
                 // dictionary of preprocessor macros
                 NSMutableDictionary * prep = [NSMutableDictionary dictionary];
 
+                if (ctx_dev->use_bfloat) {
+                    [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
+                }
+
                 MTLCompileOptions * options = [MTLCompileOptions new];
                 options.preprocessorMacros = prep;
 
@@ -556,7 +568,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
     GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction     ? "true" : "false");
     GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm            ? "true" : "false");
-    GGML_LOG_INFO("%s: bfloat                = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
+    GGML_LOG_INFO("%s: has bfloat            = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
+    GGML_LOG_INFO("%s: use bfloat            = %s\n", __func__, ctx_dev->use_bfloat                  ? "true" : "false");
     GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
@@ -608,7 +621,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
         const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm;
         const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
-        const bool has_bfloat              = ctx_dev->has_bfloat;
+        const bool use_bfloat              = ctx_dev->use_bfloat;
 
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
@@ -644,7 +657,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,                 get_rows_bf16,                  has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,                 get_rows_bf16,                  use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true);
@@ -671,10 +684,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,               mul_mv_bf16_f32,                has_simdgroup_reduction && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,          mul_mv_bf16_f32_1row,           has_simdgroup_reduction && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,            mul_mv_bf16_f32_l4,             has_simdgroup_reduction && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,              mul_mv_bf16_bf16,               has_simdgroup_reduction && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,               mul_mv_bf16_f32,                has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,          mul_mv_bf16_f32_1row,           has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,            mul_mv_bf16_f32_l4,             has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,              mul_mv_bf16_bf16,               has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              has_simdgroup_reduction);
@@ -703,7 +716,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         has_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           has_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,            mul_mv_id_bf16_f32,             has_simdgroup_reduction && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,            mul_mv_id_bf16_f32,             has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             has_simdgroup_reduction);
@@ -725,7 +738,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,               mul_mm_bf16_f32,                has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,               mul_mm_bf16_f32,                has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                has_simdgroup_mm);
@@ -747,7 +760,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,            mul_mm_id_bf16_f32,             has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,            mul_mm_id_bf16_f32,             has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             has_simdgroup_mm);
@@ -788,12 +801,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm && has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm);
@@ -825,14 +838,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction && has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
@@ -840,11 +853,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                  cpy_f32_bf16,                   has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                  cpy_f32_bf16,                   use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                  cpy_bf16_f32,                   has_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                 cpy_bf16_bf16,                  has_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                  cpy_bf16_f32,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                 cpy_bf16_bf16,                  use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
@@ -936,9 +949,9 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs
 static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
     const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm;
     const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
-    const bool has_bfloat              = ctx_dev->has_bfloat;
+    const bool use_bfloat              = ctx_dev->use_bfloat;
 
-    if (!has_bfloat) {
+    if (!use_bfloat) {
         for (size_t i = 0, n = 3; i < n; ++i) {
             if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
                 return false;
index 89f12724d3095eab21e2890e412bd71730e1741a..7e151741466b6aa11e1dd0305087643d8c52b514 100644 (file)
@@ -18,11 +18,11 @@ using namespace metal;
 //   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal.metal
 //   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal
 //
-#if __METAL_VERSION__ < 310
-#define GGML_METAL_NO_BFLOAT
+#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16)
+#undef GGML_METAL_USE_BF16
 #endif
 
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 typedef matrix<bfloat, 4, 4> bfloat4x4;
 #endif
 
@@ -41,7 +41,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
     reg = (type4x4)(*src);
 }
 
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template <typename type4x4>
 void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
     reg = (type4x4)(*src);
@@ -2082,7 +2082,7 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
 template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>;
 template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>;
 template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float,  float4>;
 template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
 #endif
@@ -2155,7 +2155,7 @@ kernel void kernel_mul_mv_1row(
 typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
 
 template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
 #endif
 
@@ -2217,7 +2217,7 @@ kernel void kernel_mul_mv_l4(
 typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
 
 template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
 #endif
 
@@ -3249,7 +3249,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_
 template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  128>;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  256>;
 
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64>;
 template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80>;
 template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96>;
@@ -3648,7 +3648,7 @@ kernel void kernel_flash_attn_ext_vec(
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  128>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 128>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 128>;
@@ -3658,7 +3658,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0,  2, dequantize_q8_0, 128>;
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  256>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,   1, dequantize_bf16, 256>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0,  2, dequantize_q4_0, 256>;
@@ -3715,12 +3715,12 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
 
 template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy<float,  float>;
 template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy<float,  half>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy<float,  bfloat>;
 #endif
 template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy<half,   float>;
 template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy<half,   half>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy<bfloat, float>;
 template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
 #endif
@@ -6628,7 +6628,7 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
 
 template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
 #endif
 
@@ -6662,7 +6662,7 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de
 
 template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16>;
 #endif
 template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>;
@@ -6693,7 +6693,7 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
 
 template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>;
 template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mm_id_bf16_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4,     1,     dequantize_bf16>;
 #endif
 template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0,    2,     dequantize_q4_0>;
@@ -6919,7 +6919,7 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float
 
 template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
 template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
-#if !defined(GGML_METAL_NO_BFLOAT)
+#if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
 #endif
 template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;