metal : fix checks for available FA kernels (#15700)
author    Georgi Gerganov <redacted>
          Sun, 31 Aug 2025 16:43:30 +0000 (19:43 +0300)
committer GitHub <redacted>
          Sun, 31 Aug 2025 16:43:30 +0000 (19:43 +0300)
* metal : fix checks for available FA kernels

ggml-ci

* cont : fix comment [no ci]

ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 1f93633d91f260e6a9821904b094a0c1a4f1227a..3d16a1dcd461e10fbedf18490c6245c5ed6ee483 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -523,13 +523,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@@ -1562,13 +1555,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,      flash_attn_ext_vec_f16_h40,      has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,     flash_attn_ext_vec_bf16_h40,     has_simdgroup_reduction && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,     flash_attn_ext_vec_q4_0_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,     flash_attn_ext_vec_q4_1_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,     flash_attn_ext_vec_q5_0_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,     flash_attn_ext_vec_q5_1_h40,     has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,     flash_attn_ext_vec_q8_0_h40,     has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,      flash_attn_ext_vec_f16_h64,      has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,     flash_attn_ext_vec_bf16_h64,     has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,     flash_attn_ext_vec_q4_0_h64,     has_simdgroup_reduction);
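For context, a paraphrased sketch of the registration pattern these lines follow (the real macro is defined elsewhere in ggml-metal.m; this is not its verbatim expansion):

    // GGML_METAL_ADD_KERNEL(e, name, supported) roughly expands to:
    //     if (supported) {
    //         ctx->kernels[e].pipeline = /* pipeline for Metal function "kernel_" #name */;
    //     }
    // so dropping the *_h40 vec entries simply stops building those pipelines;
    // has_simdgroup_reduction gates the vec kernels, and use_bfloat
    // additionally gates the bf16 variants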
@@ -1909,9 +1895,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[0]->ne[0] == 32) {
-                // head size == 32 (e.g. bert-bge-small)
-                // TODO: not sure if it is worth adding kernels for this size
+            // for new head sizes, add checks here
+            if (op->src[0]->ne[0] != 40 &&
+                op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 80 &&
+                op->src[0]->ne[0] != 96 &&
+                op->src[0]->ne[0] != 112 &&
+                op->src[0]->ne[0] != 128 &&
+                op->src[0]->ne[0] != 192 &&
+                op->src[0]->ne[0] != 256) {
                 return false;
             }
             if (op->src[0]->ne[0] == 576) {
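For reference, the gate this hunk introduces can be read as a standalone predicate. A minimal C sketch (a paraphrase, not code from the patch; the helper name is made up), assuming op->src[0]->ne[0] is the attention head size:

    // head sizes that have Metal FA kernels after this change; anything else
    // makes supports_op return false so the scheduler places the op elsewhere
    static bool fa_head_size_supported(int64_t dk) {
        switch (dk) {
            case  40: case  64: case  80: case  96:
            case 112: case 128: case 192: case 256:
                return true;
            default:
                return false;
        }
    }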
@@ -5138,10 +5130,8 @@ static int ggml_metal_encode_node(
 
                 bool use_vec_kernel = false;
 
-                // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
-                //       for now avoiding mainly to keep the number of templates/kernels a bit lower
-                //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+                // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+                if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
                             {
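The rewritten condition is easier to see as a predicate of its own. A hedged sketch (function name invented), assuming ne00 is the head size and ne01 the number of query rows:

    // prefer the simdgroup-matrix (non-vec) kernel when the batch of queries
    // is large enough to amortize it, or when the head size has no vec
    // specialization (40, 80 and 112 are non-vec only after this change)
    static bool use_nonvec_fa_kernel(int64_t ne00, int64_t ne01) {
        return ne01 >= 20 || ne00 == 40 || ne00 == 80 || ne00 == 112;
    }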
@@ -5329,24 +5319,6 @@ static int ggml_metal_encode_node(
                     use_vec_kernel = true;
 
                     switch (ne00) {
-                        case 40:
-                            {
-                                switch (src1->type) {
-                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
-                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
-                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
-                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
-                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
-                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
-                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
-                                    default:
-                                        {
-                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                            GGML_LOG_ERROR("add template specialization for this type\n");
-                                            GGML_ABORT("add template specialization for this type");
-                                        }
-                                }
-                            } break;
                         case 64:
                             {
                                 switch (src1->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 4fa16c4a553d290d6a0f6bc368686eb596f13f96..9c5933d24a0e3140cc25f054af8953c6511bb053 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4803,6 +4803,9 @@ kernel void kernel_flash_attn_ext_vec(
         ushort3   ntg[[threads_per_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    static_assert(DK % 32 == 0, "DK must be divisible by 32");
+    static_assert(DV % 32 == 0, "DV must be divisible by 32");
+
     const short nsg = ntg.y; // number of simdgroups
     const short iwg = tgpig[2]%nwg;
 
@@ -5160,16 +5163,6 @@ kernel void kernel_flash_attn_ext_vec(
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  40, 40, 8>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 40, 40, 8>;
-#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0,        8, dequantize_q4_0_t4, block_q4_0,  8, dequantize_q4_0_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1,        8, dequantize_q4_1_t4, block_q4_1,  8, dequantize_q4_1_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0,        8, dequantize_q5_0_t4, block_q5_0,  8, dequantize_q5_0_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1,        8, dequantize_q5_1_t4, block_q5_1,  8, dequantize_q5_1_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0,        8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 40, 40, 8>;
-
 template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  64, 64, 8>;
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4,           1, dequantize_bf16_t4, bfloat4,     1, dequantize_bf16_t4, 64, 64, 8>;
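Following the new "for new head sizes, add checks here" comment, supporting another vec head size means instantiating the template here, plus adding the matching enum value and GGML_METAL_ADD_KERNEL entry on the host side and whitelisting the size in ggml_metal_supports_op. A hypothetical sketch for DK = DV = 32, which satisfies the new asserts (the trailing template parameter copies the H64 instantiations and would need tuning):

    // hypothetical example only -- not added by this commit:
    template [[host_name("kernel_flash_attn_ext_vec_f16_h32")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  32, 32, 8>;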