CANN: Add basic support for the Flash Attention kernel (#13627)
author Bizhao Shi <redacted>
Mon, 26 May 2025 02:20:18 +0000 (10:20 +0800)
committer GitHub <redacted>
Mon, 26 May 2025 02:20:18 +0000 (10:20 +0800)
* cann: add the basic FA support

* cann: update the readme

* cann: update the FlashAttention with PSEShift

* cann: update the input parameters in FA

* cann: update the alibi with max_bias

* cann: add the constraints of softcap

* cann: update the docs CANN.md

* cann: update the docs CANN.md

* cann: fix typo of CANN.md

* cann: add some comments and update the CANN.md

* cann: update the CANN.md

* cann: update the inner precise for fusedInferAttention

* cann: update the constraints of flash_attn_ext on ggml-cann.cpp

* cann: clean the whitespace

* cann: clean the whitespace

* cann: add a new endline

docs/backend/CANN.md [changed mode: 0644->0755]
ggml/src/ggml-cann/CMakeLists.txt [changed mode: 0644->0755]
ggml/src/ggml-cann/Doxyfile [changed mode: 0644->0755]
ggml/src/ggml-cann/acl_tensor.cpp [changed mode: 0644->0755]
ggml/src/ggml-cann/acl_tensor.h [changed mode: 0644->0755]
ggml/src/ggml-cann/aclnn_ops.cpp [changed mode: 0644->0755]
ggml/src/ggml-cann/aclnn_ops.h [changed mode: 0644->0755]
ggml/src/ggml-cann/common.h [changed mode: 0644->0755]
ggml/src/ggml-cann/ggml-cann.cpp [changed mode: 0644->0755]

old mode 100644 (file)
new mode 100755 (executable)
index e172ec5..a5ba617
@@ -280,6 +280,15 @@ cmake --build build --config release
 ### **GitHub contribution**:
 Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay.
 
+## Updates
+### Basic Flash Attention Support
+A basic FA kernel implemented with aclnn ops has been added in aclnn_ops.cpp.
+Currently, FA only supports cases with FP16 KV tensors and no logit softcap.
+Since the aclnn flash-attention interface cannot support logit softcap, future updates will be limited to a quantized version.
+
+Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn).
+
+We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd. for their help during the code development and the pull request.
 
 ## TODO
 - Support more models and data types.
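The "no logit softcap" constraint documented above refers to the third of the three float parameters that a GGML_OP_FLASH_ATTN_EXT node carries in op_params (scale, max_bias, logit_softcap); the CANN backend rejects the op whenever that third value is nonzero. A minimal sketch of reading these parameters, mirroring the new kernel and supports_op code further down in this commit (the helper name is illustrative):

    #include <string.h>
    #include "ggml.h"

    // op_params layout of a GGML_OP_FLASH_ATTN_EXT node: [scale, max_bias, logit_softcap, ...]
    static void read_fa_params(const struct ggml_tensor * dst,
                               float * scale, float * max_bias, float * logit_softcap) {
        memcpy(scale,         (const float *) dst->op_params + 0, sizeof(float));
        memcpy(max_bias,      (const float *) dst->op_params + 1, sizeof(float));
        memcpy(logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
    }
    // the CANN FA path is taken only when *logit_softcap == 0.0f and K/V are F16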
old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
index f5462c5..f311864
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:
old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
index 9c67664..437ece2
@@ -66,6 +66,7 @@
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
 #include <aclnnop/aclnn_grouped_matmul_v2.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <float.h>
 
 #include <cmath>
 #include <vector>
 
 #include "ggml-impl.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
 #include "../ggml-common.h"
 
+
 void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
                  aclTensor ** acl_src1, aclTensor ** acl_dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2861,3 +2864,330 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             break;
     }
 }
+
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+    ggml_tensor* src0 = dst->src[0]; // q, fp32
+    ggml_tensor* src1 = dst->src[1]; // k, fp16
+    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
+    float maxBias = 0.0f;
+    float scaleValue = 1.0f;
+    float logitSoftcap = 0.0f;
+    memcpy(&scaleValue,    (float*)dst->op_params + 0, sizeof(float));
+    memcpy(&maxBias,       (float*)dst->op_params + 1, sizeof(float));
+    memcpy(&logitSoftcap,  (float*)dst->op_params + 2, sizeof(float));
+
+    if(logitSoftcap == 0.0f){
+        size_t faElemSize = sizeof(uint16_t);
+        auto   faDataType = ACL_FLOAT16; //ACL_BF16;
+
+        aclTensor* acl_src0_f16_tensor = nullptr;
+        aclTensor* acl_src1_f16_tensor = nullptr;
+        aclTensor* acl_src2_f16_tensor = nullptr;
+        aclTensor* acl_dst_f16_tensor  = nullptr;
+
+        // Step 1: cast the src0 (Query) to fp16 if needed
+        ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+        void* src0_f16_buffer = nullptr;
+
+        if(ggml_cann_type_mapping(src0->type) != faDataType){
+            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+            src0_f16_buffer = src0_f16_allocator.alloc(
+                                    ggml_nelements(src0) * faElemSize);
+
+            int64_t* src0_f16_ne = src0->ne;
+            size_t   src0_f16_nb[GGML_MAX_DIMS];
+            src0_f16_nb[0] = sizeof(uint16_t);
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+            }
+
+            acl_src0_f16_tensor = ggml_cann_create_tensor(
+                src0_f16_buffer, faDataType, faElemSize,
+                src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+            );
+            aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
+            ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+        }else{
+            acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        }
+
+        // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+        //         and the direct output from FusedInferAttention
+
+        acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
+        acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+
+        ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+        void* out_f16_buffer = out_f16_allocator.alloc(
+                                    ggml_nelements(dst) * faElemSize);
+
+        int64_t* out_f16_ne = src0->ne;
+        size_t out_f16_nb[GGML_MAX_DIMS];
+        out_f16_nb[0] = faElemSize;
+        for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+        }
+
+        acl_dst_f16_tensor = ggml_cann_create_tensor(
+            out_f16_buffer, faDataType, faElemSize,
+            out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+        );
+
+        // Step 3: create the PSEShift tensor if needed
+        //         this tensor is considered as mask (f16) in the llama.cpp
+
+        aclTensor* bcast_pse_tensor = nullptr;
+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
+        void* bcast_pse_buffer = nullptr;
+
+        if(src3 != nullptr){
+            bcast_pse_buffer = bcast_pse_allocator.alloc(
+                            ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
+
+            if(src0->ne[1] > 1){
+                // Case 1: broadcast pse for the prefill stage with multiple heads
+                aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src3->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+            }else{
+                // Case 2: truncate the mask to the first row and broadcast pse for the decode stage with multiple heads
+                int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
+                size_t* trunc_pse_nb = src3->nb;
+
+                aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+                    src3->data, ACL_FLOAT16, sizeof(uint16_t),
+                    trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
+
+                bcast_pse_ne[0] = src3->ne[0];
+                bcast_pse_ne[1] = src0->ne[1];
+                bcast_pse_ne[2] = src0->ne[2];
+                bcast_pse_ne[3] = src3->ne[3];
+
+                bcast_pse_nb[0] = sizeof(uint16_t);
+                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+                }
+
+                bcast_pse_tensor = ggml_cann_create_tensor(
+                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+                int64_t repeats[] = {1, src0->ne[2], 1, 1};
+                aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
+
+                ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+            }
+
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
+            if(maxBias != 0.0f){
+                // alibi
+                const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+                const int64_t n_head = src0->ne[2];
+                const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+                float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+                float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
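+                // The arange/pow sequence below materializes the standard ALiBi
+                // per-head slopes on-device (h is the 0-based head index):
+                //   slope(h) = m0^(h + 1)                              for h <  n_heads_log2_floor
+                //   slope(h) = m1^(2*(h - n_heads_log2_floor) + 1)     for h >= n_heads_log2_floor
+                // The resulting slope vector is later multiplied into the
+                // broadcast PSE (mask) tensor via InplaceMul.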
+                // init arange
+                ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                                    ne2_ne3 * faElemSize);
+                void* tmp_arange_buffer = arange_allocator.get();
+
+                // arange1: [1, ..., n_heads_log2_floor+1)
+                float start = 1;
+                float stop = n_heads_log2_floor + 1;
+                float step = 1;
+                int64_t n_elements_arange = n_heads_log2_floor;
+
+                int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+                size_t tmp_arange1_nb[] = {faElemSize};
+                aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_arange1_ne, tmp_arange1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+                aclTensor* tmp_arange2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+                    start = 1;
+                    stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+                    step = 2;
+                    n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+                    int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_arange2_nb[] = {faElemSize};
+
+                    tmp_arange2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_arange_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                                n_elements_arange);
+                }
+
+                // init mk_base
+                ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+                                                    ne2_ne3 * faElemSize);
+                void* tmp_mk_base_buffer = mk_base_allocator.get();
+                int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+                size_t tmp_mk_base1_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base1_ne, tmp_mk_base1_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+                aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+                aclTensor* tmp_mk_base2_tensor = nullptr;
+                if (n_heads_log2_floor < ne2_ne3) {
+                    int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                    size_t tmp_mk_base2_nb[] = {faElemSize};
+                    tmp_mk_base2_tensor = ggml_cann_create_tensor(
+                        (char*)tmp_mk_base_buffer +
+                            n_heads_log2_floor * faElemSize,
+                        faDataType, faElemSize,
+                        tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                    aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+                }
+
+                // init mk
+                int64_t tmp_mk_base_ne[] = {ne2_ne3};
+                size_t tmp_mk_base_nb[] = {faElemSize};
+                aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+                    tmp_arange_buffer, faDataType, faElemSize,
+                    tmp_mk_base_ne, tmp_mk_base_nb,
+                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+                // reshape mk
+                int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+                size_t tmp_mk_nb[GGML_MAX_DIMS];
+                tmp_mk_nb[0] = faElemSize;
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                    tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+                }
+                aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+                    tmp_mk_base_buffer, faDataType, faElemSize,
+                    tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+                    ACL_FORMAT_ND);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+                ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+                    tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+                    tmp_arange_tensor, tmp_mk_tensor);
+            }
+        }
+
+        // Step 4: set the inputs for FusedInferAttention.
+        int kvTensorNum = 1;
+        aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+        aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
+        aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
+        auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+        auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+
+        int64_t numHeads = src0->ne[2]; // N
+        int64_t numKeyValueHeads = src1->ne[2];
+        // double  scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+        int64_t preTokens = 65535;
+        int64_t nextTokens = 65535;
+        char layout[5] = {'B', 'N', 'S', 'D', 0};
+        int64_t sparseMode = 0;
+        int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
+        int64_t blockSize = 0;
+        int64_t antiquantMode = 0;
+        bool softmaxLseFlag = false;
+        int64_t keyAntiquantMode = 0;
+        int64_t valueAntiquantMode = 0;
+
+        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
+        // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+            acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+            bcast_pse_tensor, nullptr, // pse, mask
+            nullptr, nullptr, // actSeqLen, actSeqLenkv
+            nullptr, nullptr, // deqScale1, quantScale1
+            nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
+            nullptr, nullptr, // antiquantScale, antiquantOffset
+            nullptr, // blockTable
+            nullptr, nullptr, // qPadSize, kvPadSize
+            nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
+            nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
+            nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
+            numHeads, scaleValue, // heads, scaleValue
+            preTokens, nextTokens, // preTokens, nextTokens
+            layout, // inputLayout
+            numKeyValueHeads, // numKVHeads
+            sparseMode, innerPrecise, // sparseMode, innerPrecise
+            blockSize, antiquantMode, // blockSize, antiquantMode
+            softmaxLseFlag, // softmaxLseFlag
+            keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
+            acl_dst_f16_tensor, // attentionOut
+            nullptr // softmaxLse
+        );
+
+        // Step 6: post-processing, permute and cast to f32
+
+        int64_t new_dim[] = {0, 2, 1, 3};
+        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+        if(ggml_cann_type_mapping(dst->type) != faDataType){
+            ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+            perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+            void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+            int64_t* perm_out_f16_ne = dst->ne;
+            size_t  perm_out_f16_nb[GGML_MAX_DIMS];
+            perm_out_f16_nb[0] = faElemSize;
+            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+            }
+            aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+                perm_out_f16_buffer, faDataType, faElemSize,
+                perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+            aclnn_cast(ctx,
+                acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+            ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+        }else{
+            // only need to permute
+            aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+        }
+        ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+                                         acl_src1_f16_tensor,
+                                         acl_src2_f16_tensor,
+                                         acl_dst_f16_tensor,
+                                         acl_dst_tensor);
+        if(src3 != nullptr){
+            ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        }
+    }else{
+        GGML_ABORT("Function is not implemented.");
+    }
+}
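In summary, for each batch and head the fused call above effectively computes

    O = softmax(scale * Q * K^T + m_h * M) * V

where M is the optional f16 mask passed in as src3 (the PSEShift input) and m_h is the per-head ALiBi slope (1 when max_bias is 0). The kernel works in BNSD layout, so its output buffer reuses src0's ne = [head_dim, n_tokens, n_head, batch]; the final {0, 2, 1, 3} permute reorders it to dst's [head_dim, n_head, n_tokens, batch] before the optional cast back to F32.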
old mode 100644 (file)
new mode 100755 (executable)
index 15993cc..80ce80b
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief   Performs the extended Flash Attention operation (GGML_OP_FLASH_ATTN_EXT) using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ *          for computing scaled dot-product attention with hardware acceleration.
+ *          The result is stored in the destination tensor `dst`.
+ *
+ *          This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
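For orientation, the node that eventually reaches this function is built with ggml's public flash-attention API. A minimal, illustrative sketch (shapes are arbitrary; the mask, ALiBi bias, and logit softcap are all left at their disabled values so the CANN path can accept it):

    #include <math.h>
    #include "ggml.h"

    // q/k/v in ggml ne order: [head_dim, seq_len, n_head, batch]
    static struct ggml_tensor * build_fa_node(struct ggml_context * ctx) {
        struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 32, 8, 1);
        struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 32, 8, 1);
        struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 32, 8, 1);
        return ggml_flash_attn_ext(ctx, q, k, v, /*mask*/ NULL,
                                   /*scale*/ 1.0f / sqrtf(128.0f),
                                   /*max_bias*/ 0.0f, /*logit_softcap*/ 0.0f);
    }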
old mode 100644 (file)
new mode 100755 (executable)
old mode 100644 (file)
new mode 100755 (executable)
index 605b6a7..c0ea260
@@ -36,6 +36,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-cann/aclnn_ops.h"
 #include "ggml-cann/common.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
@@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:{
+            // derived from [ggml-cuda.cu]
+            if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
+                return false;
+            }
+            if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            float logitSoftcap = 0.0f;
+            memcpy(&logitSoftcap,  (float*)op->op_params + 2, sizeof(float));
+            if(logitSoftcap != 0.0f) {
+                return false;
+            }
+            return true;
+        }
         default:
             return false;
     }
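With the new GGML_OP_FLASH_ATTN_EXT case in ggml_backend_cann_supports_op, callers can probe at runtime whether a given FA node will be offloaded to CANN or fall back to the CPU. A minimal sketch, assuming the generic device-registry API from ggml-backend.h and a CANN device registered under the name "CANN0" (both outside this diff):

    #include "ggml-backend.h"

    static bool cann_will_run_fa(const struct ggml_tensor * fa_node) {
        ggml_backend_dev_t dev = ggml_backend_dev_by_name("CANN0");
        if (dev == NULL) {
            return false; // no CANN device available
        }
        // false for e.g. non-F16 K/V, head sizes 192/576, or a nonzero logit softcap
        return ggml_backend_dev_supports_op(dev, fa_node);
    }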