]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : disable FA kernel for HS=256 (llama/7556)
authorGeorgi Gerganov <redacted>
Mon, 27 May 2024 07:38:39 +0000 (10:38 +0300)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
ggml-ci

src/ggml-metal.m
src/ggml-metal.metal

index c9e570dbf5a3a9faea1539b66cd1ce9e5cccfb16..15fb68fc489af1c1ca14cf74b87ba5966d5b54d0 100644 (file)
@@ -184,9 +184,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,     // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -634,9 +634,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
@@ -770,6 +770,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 256) {
+                return false;
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2573,7 +2576,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                                 case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                                 case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                                 default:
                                           {
                                               GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2589,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                             switch (ne00) {
                                 case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
                                 default:
                                           {
                                               GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
index 8ff70d7a79ca7a802118826831aab252079a6a66..ce51c74d5158d8a629bfd4ca377ac4ed64ea99c4 100644 (file)
@@ -2418,7 +2418,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
 template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
 
 template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2696,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 }
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
 kernel void kernel_cpy_f16_f16(
         device  const half * src0,