]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: refactor mask handling and improve performance in FA (llama/15561)
authorChenguang Li <redacted>
Wed, 27 Aug 2025 09:21:41 +0000 (17:21 +0800)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:06 +0000 (12:54 +0300)
* CANN(flash-attn): refactor mask handling and improve performance

1. Refactored the mask computation in Flash Attention, unified the logic without separating prefill and decode.
2. Optimized performance in non-alibi scenarios by reducing one repeat operation.
3. Updated operator management to explicitly mark unsupported cases on 310P devices and when dim is not divisible by 16.

Signed-off-by: noemotiovon <redacted>
* [CANN]: fix review

Signed-off-by: noemotiovon <redacted>
* [CANN]: Optimization FA BNSD to BSND

Signed-off-by: noemotiovon <redacted>
---------

Signed-off-by: noemotiovon <redacted>
src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/ggml-cann.cpp

index bc33b99d96ea686ed1f42430cdfe7f4cd92a6a68..c42871c575822817c9d18fd2314352525e64347f 100755 (executable)
@@ -1427,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
 static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
     float m, int64_t size, float start, float stop, float step){
     int64_t ne[] = {size};
-    size_t nb[] = {sizeof(float)};
+    size_t nb[] = {sizeof(uint16_t)};
 
-    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
     void* arange_buffer = arange_allocator.get();
 
     aclTensor* arange_tensor = ggml_cann_create_tensor(
-        arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
     aclnn_arange(ctx, arange_tensor, start, stop, step, size);
 
     aclTensor* slope_tensor = ggml_cann_create_tensor(
-        slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
 
     aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
 
@@ -3180,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
-    ggml_tensor* src0 = dst->src[0]; // q, fp32
-    ggml_tensor* src1 = dst->src[1]; // k, fp16
-    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
     ggml_tensor* src3 = dst->src[3]; // mask, fp16
 
+    // B, N, S, D (uncont) -> B, S, N, D (cont)
+    int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src0_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src1_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src2_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
+    auto transpose12 = [](int64_t* ne, size_t* nb) {
+        int64_t ne_tmp = ne[1];
+        size_t  nb_tmp = nb[1];
+        ne[1] = ne[2];
+        nb[1] = nb[2];
+        ne[2] = ne_tmp;
+        nb[2] = nb_tmp;
+    };
+
+    transpose12(src0_bsnd_ne, src0_bsnd_nb);
+    transpose12(src1_bsnd_ne, src1_bsnd_nb);
+    transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
     float maxBias = 0.0f;
     float scaleValue = 1.0f;
     float logitSoftcap = 0.0f;
@@ -3206,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         void* src0_f16_buffer = nullptr;
 
         if(ggml_cann_type_mapping(src0->type) != faDataType){
-            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+            aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                src0_bsnd_nb, GGML_MAX_DIMS);
             src0_f16_buffer = src0_f16_allocator.alloc(
                                     ggml_nelements(src0) * faElemSize);
 
-            int64_t* src0_f16_ne = src0->ne;
+            int64_t* src0_f16_ne = src0_bsnd_ne;
             size_t   src0_f16_nb[GGML_MAX_DIMS];
             src0_f16_nb[0] = sizeof(uint16_t);
             for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3224,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
             ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
         }else{
-            acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+            acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                src0_bsnd_nb, GGML_MAX_DIMS);
         }
 
         // Step 2: create the acl tensors for src1 (Key), src2 (Value),
         //         and the direct output from FusedInferAttention
 
-        acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
-        acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+        acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+            src1_bsnd_nb, GGML_MAX_DIMS);
+        acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+            src2_bsnd_nb, GGML_MAX_DIMS);
 
         ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
         void* out_f16_buffer = out_f16_allocator.alloc(
                                     ggml_nelements(dst) * faElemSize);
 
-        int64_t* out_f16_ne = src0->ne;
+        int64_t* out_f16_ne = src0_bsnd_ne;
         size_t out_f16_nb[GGML_MAX_DIMS];
         out_f16_nb[0] = faElemSize;
         for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3251,88 +3282,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
         // Step 3: create the PSEShift tensor if needed
         //         this tensor is considered as mask (f16) in the llama.cpp
-
         aclTensor* bcast_pse_tensor = nullptr;
-        int64_t bcast_pse_ne[GGML_MAX_DIMS];
-        size_t bcast_pse_nb[GGML_MAX_DIMS];
         ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
-        void* bcast_pse_buffer = nullptr;
-
         if(src3 != nullptr){
-            bcast_pse_buffer = bcast_pse_allocator.alloc(
-                            ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
-            if(src0->ne[1] > 1){
-                // Case 1: broadcast pse for prefill stage with multiple head
-                aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
-                bcast_pse_ne[0] = src3->ne[0];
-                bcast_pse_ne[1] = src3->ne[1];
-                bcast_pse_ne[2] = src0->ne[2];
-                bcast_pse_ne[3] = src3->ne[3];
+            // Construct the truncated pse tensor (common for prefill/decode)
+            int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
+                src3->ne[0],        // D
+                src0->ne[1],        // S (number of Q tokens)
+                src3->ne[2],        // mask N
+                src3->ne[3]         // B
+            };
+            size_t* trunc_pse_nb = src3->nb;
+
+            aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+                src3->data, ACL_FLOAT16, sizeof(uint16_t),
+                trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
+            );
 
+            int64_t bcast_pse_ne[GGML_MAX_DIMS];
+            size_t bcast_pse_nb[GGML_MAX_DIMS];
+            bcast_pse_ne[0] = src3->ne[0];      // D
+            bcast_pse_ne[1] = src0->ne[1];      // S
+            bcast_pse_ne[2] = src0->ne[2];      // N (num_heads)
+            bcast_pse_ne[3] = src3->ne[3];      // B
+            if (maxBias == 0.0f) {
+                // When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2)
+                // Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
                 bcast_pse_nb[0] = sizeof(uint16_t);
-                for(int i = 1; i < GGML_MAX_DIMS; ++i){
-                    bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
-                }
+                bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
+                bcast_pse_nb[2] = 0;                // <---- the head dimension shares the same data
+                bcast_pse_nb[3] = src3->nb[3];
 
                 bcast_pse_tensor = ggml_cann_create_tensor(
-                    bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
-
-                int64_t repeats[] = {1, src0->ne[2], 1, 1};
-                aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
-
-                ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
-            }else{
-                // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
-                int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
-                size_t* trunc_pse_nb = src3->nb;
-
-                aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
                     src3->data, ACL_FLOAT16, sizeof(uint16_t),
-                    trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
-
-                bcast_pse_ne[0] = src3->ne[0];
-                bcast_pse_ne[1] = src0->ne[1];
-                bcast_pse_ne[2] = src0->ne[2];
-                bcast_pse_ne[3] = src3->ne[3];
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+                );
 
+                ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+            } else {
                 bcast_pse_nb[0] = sizeof(uint16_t);
-                for(int i = 1; i < GGML_MAX_DIMS; ++i){
+                for (int i = 1; i < GGML_MAX_DIMS; i++) {
                     bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
                 }
 
+                void* bcast_pse_buffer = bcast_pse_allocator.alloc(
+                    ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
+                );
+
                 bcast_pse_tensor = ggml_cann_create_tensor(
                     bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+                    bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+                );
 
                 int64_t repeats[] = {1, src0->ne[2], 1, 1};
                 aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
 
-                ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
-            }
-
-            // Compute the slope if needed. Derived from ggml_cann_softmax().
-            if(maxBias != 0.0f){
                 // alibi
+                // Compute the slope if needed. Derived from ggml_cann_softmax().
                 const int64_t n_heads = src0->ne[2];
-                ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+                ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
                 void* slope_buffer = slope_allocator.get();
                 aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
 
                 int64_t slope_ne[] = {1, 1, n_heads, 1};
                 size_t slope_nb[GGML_MAX_DIMS];
-                slope_nb[0] = sizeof(float);
+                slope_nb[0] = sizeof(uint16_t);
                 for(int i = 1;i<GGML_MAX_DIMS;i++) {
                     slope_nb[i] = slope_nb[i-1] * slope_ne[0];
                 }
 
                 aclTensor* slope_tensor = ggml_cann_create_tensor(
-                    slope_buffer, ACL_FLOAT, sizeof(float),
+                    slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
                     slope_ne, slope_nb, GGML_MAX_DIMS);
                 GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
 
-                ggml_cann_release_resources(ctx, slope_tensor);
+                ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
             }
         }
 
@@ -3349,7 +3373,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         // double  scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
         int64_t preTokens = 65535;
         int64_t nextTokens = 65535;
-        char layout[5] = {'B', 'N', 'S', 'D', 0};
+        char layout[5] = {'B', 'S', 'N', 'D', 0};
         int64_t sparseMode = 0;
         int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
         int64_t blockSize = 0;
@@ -3386,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         );
 
         // Step 6: post-processing, permute and cast to f32
-
-        int64_t new_dim[] = {0, 2, 1, 3};
         aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
-        if(ggml_cann_type_mapping(dst->type) != faDataType){
-            ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
-            perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
-            void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
-            int64_t* perm_out_f16_ne = dst->ne;
-            size_t  perm_out_f16_nb[GGML_MAX_DIMS];
-            perm_out_f16_nb[0] = faElemSize;
-            for(int i = 1; i < GGML_MAX_DIMS; ++i){
-                perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
-            }
-            aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
-                perm_out_f16_buffer, faDataType, faElemSize,
-                perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
-            aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
-            aclnn_cast(ctx,
-                acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-            ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
-        }else{
-            // only need to permute
-            aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
-        }
+        // TODO: when dst is fp16, don't need cast
+        aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
         ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
                                          acl_src1_f16_tensor,
                                          acl_src2_f16_tensor,
index cb8af42ebf95650629fa4e9bef22307443dc1946..81215425618a35466430ce24d542db2a15c788e7 100755 (executable)
@@ -2336,7 +2336,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
 #ifdef ASCEND_310P
-                    // Q4 && Q8 per group is not suppor on 310p device
+                    // Q4 && Q8 per group is not support on 310p device
                     return false;
 #endif
                     // only support contiguous for quantized types.
@@ -2354,7 +2354,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
 #ifdef ASCEND_310P
-                    // Q4 && Q8 per group is not suppor on 310p device
+                    // Q4 && Q8 per group is not support on 310p device
                     return false;
 #endif
                     // only support contiguous for quantized types.
@@ -2505,6 +2505,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
             return true;
         case GGML_OP_FLASH_ATTN_EXT:{
+#ifdef ASCEND_310P
+            // FA not support on 310p device
+            return false;
+#endif
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
                 return false;
@@ -2530,6 +2534,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
+            if (op->src[0]->ne[0] % 16 != 0) {
+                // TODO: padding to support
+                return false;
+            }
             float logitSoftcap = 0.0f;
             memcpy(&logitSoftcap,  (float*)op->op_params + 2, sizeof(float));
             if(logitSoftcap != 0.0f) {