]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: Add broadcast for softmax and FA (llama/15208)
authorhipudding <redacted>
Mon, 11 Aug 2025 14:50:31 +0000 (22:50 +0800)
committerGeorgi Gerganov <redacted>
Thu, 14 Aug 2025 11:17:28 +0000 (14:17 +0300)
* refactor softmax

* fix fa

* fix mask shape

* format

* add comments

* Remove whitespace

src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/ggml-cann.cpp

index 07d6b8b67d47c4f916ac54f467fcda3d5e3d013b..0b409ce87d2abf80b35c3eb629ce53c8210d4b58 100755 (executable)
@@ -812,7 +812,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             ggml_cann_release_resources(ctx, src_trans_tensor);
             return;
         } else {
-            GGML_ABORT("Unsupport dst is not tontiguous.");
+            GGML_ABORT("Unsupport dst is not contiguous.");
         }
     }
     ggml_cann_release_resources(ctx, acl_src, acl_dst);
@@ -1330,160 +1330,196 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
 }
 
 /**
- * @brief   Applies the Alibi (Attention with Linear Biases) mechanism to the
- * @details This function implements the Alibi mechanism, which introduces
- *          learnable biases into the attention scores to simulate relative
- *          position encoding without the need for explicit positional
- *          embeddings.
- *
- * @param ctx          The backend CANN context for executing operations.
- * @param acl_src      The source tensor representing the query or key.
- * @param acl_position The position tensor containing relative positions.
- * @param acl_dst      The destination tensor where the result will be stored.
- * @param n_head       The number of attention heads.
- * @param src_ne       The dimensions of the source tensor.
- * @param src_nb0      The byte size of the first dimension of the source
- tensor.
- * @param max_bias     The maximum bias value used in the Alibi mechanism.
- * @param dst          The destination tensor object for additional metadata.
- *
- * The function performs the following steps:
- * 1. Calculates the logarithm floor of the number of heads to determine the
-      base for bias calculation.
- * 2. Initializes arrays with arithmetic sequences and fills them with bias
-      values.
- * 3. Computes the bias tensor based on the calculated biases and arithmetic
-      sequences.
- * 4. Reshapes the bias tensor to match the dimensions of the input tensors.
- * 5. Multiplies the position tensor by the bias tensor.
- * 6. Adds the result of the multiplication to the source tensor to produce the
-      final output.
+ * @brief Generate a range of values and apply a scalar base exponentiation.
+ *
+ * This function creates an evenly spaced sequence from `start` to `stop` (exclusive),
+ * with step size `step`, stores it in a temporary buffer, and then computes:
+ *
+ * @f[
+ * slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size
+ * @f]
+ *
+ * The results are written to the provided @p slope_buffer.
+ *
+ * @param ctx           CANN backend context for memory allocation and operator execution.
+ * @param slope_buffer  Pointer to the output buffer (float array) for the computed slope values.
+ * @param m             Scalar base for the exponentiation.
+ * @param size          Number of elements in the generated sequence.
+ * @param start         Starting exponent offset.
+ * @param stop          Stopping exponent offset (exclusive).
+ * @param step          Step size for the exponent increment.
  */
-static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
-                        aclTensor* acl_position, aclTensor* acl_dst,
-                        const int n_head, int64_t* src_ne, const size_t src_nb0,
-                        float max_bias, ggml_tensor* dst) {
-    const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
-    GGML_ASSERT(src_nb0 == sizeof(float));
-    GGML_ASSERT(n_head == src_ne[2]);
-
-    const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
-
-    float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-    float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-    // init arange
-    ggml_cann_pool_alloc arange_allocator(ctx.pool(),
-                                          ne2_ne3 * ggml_type_size(dst->type));
-    void* tmp_arange_buffer = arange_allocator.get();
+static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
+    float m, int64_t size, float start, float stop, float step){
+    int64_t ne[] = {size};
+    size_t nb[] = {sizeof(float)};
 
-    // arange1: [1, ..., n_heads_log2_floor+1)
-    float start = 1;
-    float stop = n_heads_log2_floor + 1;
-    float step = 1;
-    int64_t n_elements_arange = n_heads_log2_floor;
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+    void* arange_buffer = arange_allocator.get();
 
-    int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
-    size_t tmp_arange1_nb[] = {sizeof(dst->type)};
-    aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
-        tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+    aclTensor* arange_tensor = ggml_cann_create_tensor(
+        arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+    aclnn_arange(ctx, arange_tensor, start, stop, step, size);
 
-    aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
-
-    aclTensor* tmp_arange2_tensor = nullptr;
-    if (n_heads_log2_floor < ne2_ne3) {
-        // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
-        start = 1;
-        stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
-        step = 2;
-        n_elements_arange = ne2_ne3 - n_heads_log2_floor;
-        int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
-        size_t tmp_arange2_nb[] = {sizeof(dst->type)};
-
-        aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
-            (char*)tmp_arange_buffer +
-                n_heads_log2_floor * ggml_type_size(dst->type),
-            ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
-            tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-        aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
-                     n_elements_arange);
-    }
+    aclTensor* slope_tensor = ggml_cann_create_tensor(
+        slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
 
-    // init mk_base
-    ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
-                                           ne2_ne3 * ggml_type_size(dst->type));
-    void* tmp_mk_base_buffer = mk_base_allocator.get();
-    int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
-    size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
-    aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
-        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+    aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
 
-    aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
-
-    aclTensor* tmp_mk_base2_tensor = nullptr;
-    if (n_heads_log2_floor < ne2_ne3) {
-        int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
-        size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
-        aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
-            (char*)tmp_mk_base_buffer +
-                n_heads_log2_floor * ggml_type_size(dst->type),
-            ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
-            tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-        aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
-    }
+    GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor);
+    ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor);
+}
 
-    // init mk
-    int64_t tmp_mk_base_ne[] = {ne2_ne3};
-    size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
-    aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
-        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-    aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
-        tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
-        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-    aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+/**
+ * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.
+ *
+ * This function generates slope values for each attention head according to the ALiBi
+ * (Attention with Linear Biases) method. It splits the computation into two ranges depending
+ * on whether the head index is less than @p n_head_log2 or not, and uses different base values
+ * (`m0` and `m1`) for the exponentiation.
+ *
+ * @f[
+ * slope[h] =
+ * \begin{cases}
+ * m_0^{(h + 1)}, & h < n\_head\_log2 \\
+ * m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2
+ * \end{cases}
+ * \quad , \quad \text{if } max\_bias > 0
+ * @f]
+ *
+ * If @p max_bias <= 0, all slope values are set to 1.0.
+ *
+ * @param ctx           CANN backend context for memory allocation and operator execution.
+ * @param n_head        Total number of attention heads.
+ * @param slope_buffer  Pointer to the output buffer (float array) for storing slopes.
+ * @param max_bias      Maximum bias value for slope computation.
+ *
+*/
+static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
+    void* slope_buffer, float max_bias) {
+    const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+    float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // const float slope = (max_bias > 0.0f) ?
+    //                          h < n_head_log2 ?
+    //                              powf(m0, h + 1) :
+    //                              powf(m1, 2*(h - n_head_log2) + 1) :
+    //                          1.0f;
+    // arange1
+    float start = 0 + 1;
+    float end   = (n_head_log2 - 1) + 1;
+    float step  = 1;
+    float count = n_head_log2;
+    // end needs to be +1 because aclnn uses a left-closed, right-open interval.
+    aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
+    if (n_head_log2 < n_head) {
+        // arange2
+        start = 2 * (n_head_log2 - n_head_log2) + 1;
+        end   = 2 * ((n_head - 1) - n_head_log2) + 1;
+        step  = 2;
+        count = n_head - n_head_log2;
+        aclnn_get_slope_inner(
+            ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
+            m1, count, start, end + 1, step);
+    }
+}
 
-    // reshape mk
-    int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
-    size_t tmp_mk_nb[GGML_MAX_DIMS];
-    tmp_mk_nb[0] = ggml_type_size(dst->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+/**
+ * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.
+ *
+ * This function computes the ALiBi slopes for each attention head (if max_bias > 0),
+ * multiplies them with the attention mask to produce bias tensors, and adds these biases
+ * to the destination tensor (@p dst).
+ *
+ * The function performs necessary broadcasting of the mask and slope tensors to match
+ * the shape of the destination tensor, then applies element-wise multiplication and addition
+ * using CANN operators.
+ *
+ * @param ctx         CANN backend context for memory management and operator execution.
+ * @param mask        Input attention mask tensor, assumed to be contiguous.
+ * @param dst         Destination tensor to which ALiBi biases will be added.
+ * @param dst_ptr     Pointer to the memory of the destination tensor.
+ * @param max_bias    Maximum bias value controlling the slope scaling.
+ *
+ * @note
+ * - Write data into dst_ptr using only the shape information of the dst tensor.
+ * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.
+ */
+static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
+    ggml_tensor* dst, void* dst_ptr, float max_bias) {
+    void* slope_buffer = nullptr;
+    void* bias_buffer = nullptr;
+
+    if (max_bias > 0.0f) {
+        int64_t n_heads = dst->ne[2];
+        ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+        slope_buffer = slope_allocator.get();
+        ggml_cann_pool_alloc bias_allocator(
+                    ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
+        bias_buffer = bias_allocator.get();
+        aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
     }
-    aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
-        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
-        ACL_FORMAT_ND);
 
-    // acl_position * mk
-    int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
-    size_t tmp_output_nb[GGML_MAX_DIMS];
-    tmp_output_nb[0] = ggml_type_size(dst->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
+    // broadcast for mask, slop and dst;
+    int64_t nr2 = dst->ne[2] / mask->ne[2];
+    int64_t nr3 = dst->ne[3] / mask->ne[3];
+
+    // broadcast the mask across rows
+    int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };
+    size_t  mask_nb[] = {
+        mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2],
+        mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3]
+    };
+
+    int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };
+    size_t  dst_nb[] = {
+        dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2],
+        dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3]
+    };
+
+    // slope is a 1 dim tensor, slope.ne2 == dst.ne2
+    int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };
+    size_t  slope_nb[GGML_MAX_DIMS + 2];
+    slope_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
+        slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
     }
-    ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
-    void* tmp_output_buffer = output_allocator.get();
-    aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
-        tmp_output_buffer, ggml_cann_type_mapping(dst->type),
-        ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
-        ACL_FORMAT_ND);
-    aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
 
-    // add
-    aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
-    ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
-        tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
-        tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
+    aclTensor* acl_slope = ggml_cann_create_tensor(
+                            slope_buffer, ACL_FLOAT, sizeof(float),
+                            slope_ne, slope_nb, GGML_MAX_DIMS + 2);
+    aclTensor* acl_mask = ggml_cann_create_tensor(
+                            mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);
+
+    // write data into dst_ptr using only the shape information of the dst tensor.
+    aclTensor* acl_dst  = ggml_cann_create_tensor(
+                            dst_ptr, ggml_cann_type_mapping(dst->type),
+                            ggml_type_size(dst->type), dst_ne, dst_nb,
+                            GGML_MAX_DIMS + 2);
+
+    if (max_bias > 0.0f) {
+        int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };
+        size_t  bias_nb[GGML_MAX_DIMS + 2];
+        bias_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
+            bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];
+        }
+        aclTensor* bias_tensor = ggml_cann_create_tensor(
+                                    bias_buffer, ACL_FLOAT, sizeof(float),
+                                    bias_ne, bias_nb, GGML_MAX_DIMS + 2);
+
+        aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor);
+        aclnn_add(ctx, acl_dst, bias_tensor);
+        ggml_cann_release_resources(ctx, bias_tensor);
+    } else {
+        aclnn_add(ctx, acl_dst, acl_mask);
+    }
+    ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst);
 }
 
-void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_cann_dup(ctx, dst);
 }
 
@@ -1501,118 +1537,41 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param acl_dst The destination tensor where the softmax results will be
  * stored.
  */
-static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
-                          int64_t dim, aclTensor* acl_dst) {
+static void aclnn_softmax(ggml_backend_cann_context & ctx,
+    aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) {
     GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
 }
 
-void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor* src0 = dst->src[0];
     ggml_tensor* src1 = dst->src[1];  // mask
 
     aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclTensor* acl_dst  = ggml_cann_create_tensor(dst);
 
-    float scale = 1.0f;
+    float scale    = 1.0f;
     float max_bias = 0.0f;
 
-    memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // input mul scale
     aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+    ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));
+    void* src_tensor_buffer = src_tensor_allocator.get();
+    aclTensor* softmax_tensor = ggml_cann_create_tensor(
+        src_tensor_buffer, ggml_cann_type_mapping(src0->type),
+        ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS);
 
-    size_t n_bytes = ggml_nbytes(src0);
-    ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
-    void* input_mul_scale_buffer = mul_scale_allocator.get();
-    aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
-        input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
-        src0->nb, GGML_MAX_DIMS);
-
-    bool inplace = false;
-    aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
+    aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false);
 
     // mask
-    aclTensor* acl_src1_fp32_tensor = nullptr;
-    aclTensor* tmp_mask_tensor = nullptr;
-    ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
     if (src1) {
-        const bool use_f16 = src1->type == GGML_TYPE_F16;
-        if (use_f16) {
-            // cast to fp32
-            size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
-            size_t src1_fp32_nb[GGML_MAX_DIMS];
-            src1_fp32_nb[0] = sizeof(float_t);
-            for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
-            }
-            src1_fp32_allocator.alloc(n_bytes);
-            void* src1_fp32_buffer = src1_fp32_allocator.get();
-            acl_src1_fp32_tensor = ggml_cann_create_tensor(
-                src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
-                src1_fp32_nb, GGML_MAX_DIMS);
-            aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
-            aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
-            ggml_cann_release_resources(ctx, acl_src1);
-        } else {
-            acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
-        }
-
-        // broadcast the mask across rows, only use ne11 of ne01 in mask
-        if (src1->ne[1] != src0->ne[1]) {
-            // mask shape: [1,1,ne11,ne10]
-            int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
-            size_t tmp_mask_nb[GGML_MAX_DIMS];
-            tmp_mask_nb[0] = sizeof(float_t);
-            for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
-            }
-            tmp_mask_tensor = ggml_cann_create_tensor(
-                src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
-                GGML_MAX_DIMS, ACL_FORMAT_ND);
-        }
-
-        // alibi
-        const int n_head = src0->ne[2];
-        const size_t src_nb0 = src0->nb[0];
-
-        n_bytes = ggml_nbytes(dst);
-        ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
-        void* output_buffer = output_allocator.get();
-        aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
-            output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
-            dst->nb, GGML_MAX_DIMS);
-        if (max_bias <= 0.0f) {
-            // slope = 1.0
-            if (tmp_mask_tensor) {
-                aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
-                          alibi_output_tensor);
-            } else {
-                aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
-                          alibi_output_tensor);
-            }
-        } else {
-            // slope != 1.0
-            if (tmp_mask_tensor) {
-                aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
-                            alibi_output_tensor, n_head, src0->ne, src_nb0,
-                            max_bias, dst);
-            } else {
-                aclnn_alibi(ctx, acl_input_mul_scale_tensor,
-                            acl_src1_fp32_tensor, alibi_output_tensor, n_head,
-                            src0->ne, src_nb0, max_bias, dst);
-            }
-        }
-
-        // softmax
-        aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
-        ggml_cann_release_resources(ctx, alibi_output_tensor);
-    } else {
-        aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
+        aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);
     }
-
-    ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
-        acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
+    // softmax
+    aclnn_softmax(ctx, softmax_tensor, 3, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor);
 }
 
 /**
@@ -3208,104 +3167,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             // Compute the slope if needed. Derived from ggml_cann_softmax().
             if(maxBias != 0.0f){
                 // alibi
-                const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
-                const int64_t n_head = src0->ne[2];
-                const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
-                float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
-                float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
-                // init arange
-                ggml_cann_pool_alloc arange_allocator(ctx.pool(),
-                                                    ne2_ne3 * faElemSize);
-                void* tmp_arange_buffer = arange_allocator.get();
-
-                // arange1: [1, ..., n_heads_log2_floor+1)
-                float start = 1;
-                float stop = n_heads_log2_floor + 1;
-                float step = 1;
-                int64_t n_elements_arange = n_heads_log2_floor;
-
-                int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
-                size_t tmp_arange1_nb[] = {faElemSize};
-                aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
-                    tmp_arange_buffer, faDataType, faElemSize,
-                    tmp_arange1_ne, tmp_arange1_nb,
-                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-
-                aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
-
-                aclTensor* tmp_arange2_tensor = nullptr;
-                if (n_heads_log2_floor < ne2_ne3) {
-                    // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
-                    start = 1;
-                    stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
-                    step = 2;
-                    n_elements_arange = ne2_ne3 - n_heads_log2_floor;
-                    int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
-                    size_t tmp_arange2_nb[] = {faElemSize};
-
-                    aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
-                        (char*)tmp_arange_buffer +
-                            n_heads_log2_floor * faElemSize,
-                        faDataType, faElemSize,
-                        tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-                    aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
-                                n_elements_arange);
+                const int64_t n_heads = src0->ne[2];
+                ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+                void* slope_buffer = slope_allocator.get();
+                aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
+
+                int64_t slope_ne[] = {1, 1, n_heads, 1};
+                size_t slope_nb[GGML_MAX_DIMS];
+                slope_nb[0] = sizeof(float);
+                for(int i = 1;i<GGML_MAX_DIMS;i++) {
+                    slope_nb[i] = slope_nb[i-1] * slope_ne[0];
                 }
 
-                // init mk_base
-                ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
-                                                    ne2_ne3 * faElemSize);
-                void* tmp_mk_base_buffer = mk_base_allocator.get();
-                int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
-                size_t tmp_mk_base1_nb[] = {faElemSize};
-                aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
-                    tmp_mk_base_buffer, faDataType, faElemSize,
-                    tmp_mk_base1_ne, tmp_mk_base1_nb,
-                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-
-                aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
-
-                aclTensor* tmp_mk_base2_tensor = nullptr;
-                if (n_heads_log2_floor < ne2_ne3) {
-                    int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
-                    size_t tmp_mk_base2_nb[] = {faElemSize};
-                    aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
-                        (char*)tmp_mk_base_buffer +
-                            n_heads_log2_floor * faElemSize,
-                        faDataType, faElemSize,
-                        tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-                    aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
-                }
+                aclTensor* slope_tensor = ggml_cann_create_tensor(
+                    slope_buffer, ACL_FLOAT, sizeof(float),
+                    slope_ne, slope_nb, GGML_MAX_DIMS);
+                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
 
-                // init mk
-                int64_t tmp_mk_base_ne[] = {ne2_ne3};
-                size_t tmp_mk_base_nb[] = {faElemSize};
-                aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
-                    tmp_mk_base_buffer, faDataType, faElemSize,
-                    tmp_mk_base_ne, tmp_mk_base_nb,
-                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-                aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
-                    tmp_arange_buffer, faDataType, faElemSize,
-                    tmp_mk_base_ne, tmp_mk_base_nb,
-                    GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-                aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
-
-                // reshape mk
-                int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
-                size_t tmp_mk_nb[GGML_MAX_DIMS];
-                tmp_mk_nb[0] = faElemSize;
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
-                }
-                aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
-                    tmp_mk_base_buffer, faDataType, faElemSize,
-                    tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
-                    ACL_FORMAT_ND);
-                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
-
-                ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
-                    tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
-                    tmp_arange_tensor, tmp_mk_tensor);
+                ggml_cann_release_resources(ctx, slope_tensor);
             }
         }
 
index cf575b367500a1c75f422b09aad011638c543a11..3d3520f195951b76186eeb5e7019e1f6d05c1784 100755 (executable)
@@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // only support F32 and F16.
                 return false;
             }
-            return true;
+            return ggml_is_contiguous(op);
         } break;
         case GGML_OP_CONT: {
             // TODO: support GGML_TYPE_BF16
@@ -2456,8 +2456,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             // value of paddingW should be at most half of kernelW
             return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
         }
-        case GGML_OP_SUM:
         case GGML_OP_DUP:
+            return ggml_is_contiguous(op);
+        case GGML_OP_SUM:
         case GGML_OP_IM2COL:
         case GGML_OP_CONCAT:
         case GGML_OP_REPEAT:
@@ -2503,9 +2504,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if (op->src[2]) {
                 return false;
             }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+            return true;
         case GGML_OP_FLASH_ATTN_EXT:{
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2532,11 +2531,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
             float logitSoftcap = 0.0f;
             memcpy(&logitSoftcap,  (float*)op->op_params + 2, sizeof(float));
             if(logitSoftcap != 0.0f) {