]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: support flash attention for head dim not multiple of 16, fix ALiBi slope offset...
authorChenguang Li <redacted>
Thu, 19 Mar 2026 03:02:42 +0000 (11:02 +0800)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
- Allow FLASH_ATTN_EXT when head dimension D is not a multiple of 16 by
  padding Q/K/V to D_padded = GGML_PAD(D, 16), running FusedInferAttentionScoreV2,
  then slicing the output back to D (ggml-cann.cpp + aclnn_ops.cpp).
- Fix aclnn_get_slope second-part offset: use ggml_type_size(dtype) instead of
  sizeof(float) so ALiBi slopes are correct when dtype is F16 (e.g. GQA with
  48 heads); fixes buffer overflow and large numerical errors in those cases.

src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/ggml-cann.cpp

index fc7c3e3b724c51031ea071adf666812a973f6111..4b7aab1e72d50097f50052168038beb6b7181ccc 100644 (file)
@@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx,
         end   = 2 * ((n_head - 1) - n_head_log2) + 1;
         step  = 2;
         count = n_head - n_head_log2;
-        aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step,
-                              dtype);
+        aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1,
+                              step, dtype);
     }
 }
 
@@ -3599,6 +3599,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
         acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS);
         acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS);
 
+        // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16
+        //           (required by FusedInferAttentionScoreV2)
+        const int64_t D         = src0->ne[0];
+        const int64_t D_padded  = GGML_PAD(D, 16);
+        const bool needs_padding = (D != D_padded);
+
+        ggml_cann_pool_alloc q_pad_allocator(ctx.pool());
+        ggml_cann_pool_alloc k_pad_allocator(ctx.pool());
+        ggml_cann_pool_alloc v_pad_allocator(ctx.pool());
+
+        if (needs_padding) {
+            int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 };
+
+            auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne,
+                                     ggml_cann_pool_alloc & allocator) {
+                int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] };
+                size_t  pad_nb[GGML_MAX_DIMS];
+                pad_nb[0] = faElemSize;
+                for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+                    pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1];
+                }
+                int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3];
+                void *  buffer    = allocator.alloc(nelements * faElemSize);
+                acl_tensor_ptr padded =
+                    ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS);
+                aclnn_pad(ctx, tensor.get(), padded.get(), paddings);
+                tensor = std::move(padded);
+            };
+
+            pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator);
+            pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator);
+            pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator);
+
+            src0_bsnd_ne[0] = D_padded;
+            src1_bsnd_ne[0] = D_padded;
+            src2_bsnd_ne[0] = D_padded;
+        }
+
         // Step 3: create the PSEShift tensor if needed
         //         this tensor is considered as mask (f16) in the llama.cpp
         acl_tensor_ptr       bcast_pse_tensor;
@@ -3688,17 +3726,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
 
         GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
         acl_tensor_ptr       fa_dst_tensor;
-        acl_tensor_ptr       acl_dst_tensor;
         ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
-        if (dst->type == GGML_TYPE_F32) {
-            void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
-
+        if (dst->type == GGML_TYPE_F32 || needs_padding) {
             int64_t * out_f16_ne = src0_bsnd_ne;
             size_t    out_f16_nb[GGML_MAX_DIMS];
             out_f16_nb[0] = faElemSize;
             for (int i = 1; i < GGML_MAX_DIMS; ++i) {
                 out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
             }
+            int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3];
+            void *  out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize);
 
             fa_dst_tensor =
                 ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS);
@@ -3730,8 +3767,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
                                 nullptr                                // softmaxLse
         );
 
-        if (dst->type == GGML_TYPE_F32) {
-            // Step 6: post-processing, permute and cast to f32
+        // Step 6: post-processing — slice padded output and/or cast to f32
+        if (needs_padding) {
+            ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool());
+
+            if (dst->type == GGML_TYPE_F32) {
+                int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] };
+                size_t  sliced_nb[GGML_MAX_DIMS];
+                sliced_nb[0] = faElemSize;
+                for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+                    sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1];
+                }
+                int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3];
+                void *  sliced_buffer    = sliced_f16_allocator.alloc(sliced_nelements * faElemSize);
+                acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize,
+                                                                           sliced_ne, sliced_nb, GGML_MAX_DIMS);
+
+                GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
+                                        (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get());
+
+                acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
+                aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
+            } else {
+                acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
+                GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(),
+                                        (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get());
+            }
+        } else if (dst->type == GGML_TYPE_F32) {
             acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst);
             aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type));
         }
index 3f3de9f0bcba5808aadb7f67099744491a7b9945..a682746bb428f8c532c3a4b27526b0403040d72d 100644 (file)
@@ -2503,10 +2503,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
                     // different head sizes of K and V are not supported yet
                     return false;
                 }
-                if (op->src[0]->ne[0] % 16 != 0) {
-                    // TODO: padding to support
-                    return false;
-                }
                 float logitSoftcap = 0.0f;
                 memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
                 if (logitSoftcap != 0.0f) {