ggml_cann_release_resources(ctx, src_trans_tensor);
return;
} else {
- GGML_ABORT("Unsupport dst is not tontiguous.");
+ GGML_ABORT("Unsupport dst is not contiguous.");
}
}
ggml_cann_release_resources(ctx, acl_src, acl_dst);
}
/**
- * @brief Applies the Alibi (Attention with Linear Biases) mechanism to the
- * @details This function implements the Alibi mechanism, which introduces
- * learnable biases into the attention scores to simulate relative
- * position encoding without the need for explicit positional
- * embeddings.
- *
- * @param ctx The backend CANN context for executing operations.
- * @param acl_src The source tensor representing the query or key.
- * @param acl_position The position tensor containing relative positions.
- * @param acl_dst The destination tensor where the result will be stored.
- * @param n_head The number of attention heads.
- * @param src_ne The dimensions of the source tensor.
- * @param src_nb0 The byte size of the first dimension of the source
- tensor.
- * @param max_bias The maximum bias value used in the Alibi mechanism.
- * @param dst The destination tensor object for additional metadata.
- *
- * The function performs the following steps:
- * 1. Calculates the logarithm floor of the number of heads to determine the
- base for bias calculation.
- * 2. Initializes arrays with arithmetic sequences and fills them with bias
- values.
- * 3. Computes the bias tensor based on the calculated biases and arithmetic
- sequences.
- * 4. Reshapes the bias tensor to match the dimensions of the input tensors.
- * 5. Multiplies the position tensor by the bias tensor.
- * 6. Adds the result of the multiplication to the source tensor to produce the
- final output.
+ * @brief Generate a range of values and apply a scalar base exponentiation.
+ *
+ * This function creates an evenly spaced sequence from `start` to `stop` (exclusive),
+ * with step size `step`, stores it in a temporary buffer, and then computes:
+ *
+ * @f[
+ * slope[i] = m^{\left( start + i \cdot step \right)}, \quad 0 \le i < size
+ * @f]
+ *
+ * The results are written to the provided @p slope_buffer.
+ *
+ * @param ctx CANN backend context for memory allocation and operator execution.
+ * @param slope_buffer Pointer to the output buffer (float array) for the computed slope values.
+ * @param m Scalar base for the exponentiation.
+ * @param size Number of elements in the generated sequence.
+ * @param start Starting exponent offset.
+ * @param stop Stopping exponent offset (exclusive).
+ * @param step Step size for the exponent increment.
*/
-static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
- aclTensor* acl_position, aclTensor* acl_dst,
- const int n_head, int64_t* src_ne, const size_t src_nb0,
- float max_bias, ggml_tensor* dst) {
- const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
- GGML_ASSERT(src_nb0 == sizeof(float));
- GGML_ASSERT(n_head == src_ne[2]);
-
- const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
-
- float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- // init arange
- ggml_cann_pool_alloc arange_allocator(ctx.pool(),
- ne2_ne3 * ggml_type_size(dst->type));
- void* tmp_arange_buffer = arange_allocator.get();
+static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
+ float m, int64_t size, float start, float stop, float step){
+ int64_t ne[] = {size};
+ size_t nb[] = {sizeof(float)};
- // arange1: [1, ..., n_heads_log2_floor+1)
- float start = 1;
- float stop = n_heads_log2_floor + 1;
- float step = 1;
- int64_t n_elements_arange = n_heads_log2_floor;
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+ void* arange_buffer = arange_allocator.get();
- int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
- size_t tmp_arange1_nb[] = {sizeof(dst->type)};
- aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
- tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclTensor* arange_tensor = ggml_cann_create_tensor(
+ arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+ aclnn_arange(ctx, arange_tensor, start, stop, step, size);
- aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
-
- aclTensor* tmp_arange2_tensor = nullptr;
- if (n_heads_log2_floor < ne2_ne3) {
- // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
- start = 1;
- stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
- step = 2;
- n_elements_arange = ne2_ne3 - n_heads_log2_floor;
- int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
- size_t tmp_arange2_nb[] = {sizeof(dst->type)};
-
- aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
- (char*)tmp_arange_buffer +
- n_heads_log2_floor * ggml_type_size(dst->type),
- ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
- tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
- n_elements_arange);
- }
+ aclTensor* slope_tensor = ggml_cann_create_tensor(
+ slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
- // init mk_base
- ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
- ne2_ne3 * ggml_type_size(dst->type));
- void* tmp_mk_base_buffer = mk_base_allocator.get();
- int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
- size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
- aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
- tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
- aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
-
- aclTensor* tmp_mk_base2_tensor = nullptr;
- if (n_heads_log2_floor < ne2_ne3) {
- int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
- size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
- aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
- (char*)tmp_mk_base_buffer +
- n_heads_log2_floor * ggml_type_size(dst->type),
- ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
- tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
- }
+ GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, sc, arange_tensor, slope_tensor);
+ ggml_cann_release_resources(ctx, sc, arange_tensor, slope_tensor);
+}
- // init mk
- int64_t tmp_mk_base_ne[] = {ne2_ne3};
- size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
- aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
- tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
- tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+/**
+ * @brief Compute slope values for multiple attention heads based on ALiBi bias parameters.
+ *
+ * This function generates slope values for each attention head according to the ALiBi
+ * (Attention with Linear Biases) method. It splits the computation into two ranges depending
+ * on whether the head index is less than @p n_head_log2 or not, and uses different base values
+ * (`m0` and `m1`) for the exponentiation.
+ *
+ * @f[
+ * slope[h] =
+ * \begin{cases}
+ * m_0^{(h + 1)}, & h < n\_head\_log2 \\
+ * m_1^{\left( 2 \cdot (h - n\_head\_log2) + 1 \right)}, & h \geq n\_head\_log2
+ * \end{cases}
+ * \quad , \quad \text{if } max\_bias > 0
+ * @f]
+ *
+ * If @p max_bias <= 0, all slope values are set to 1.0.
+ *
+ * @param ctx CANN backend context for memory allocation and operator execution.
+ * @param n_head Total number of attention heads.
+ * @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
+ * @param max_bias Maximum bias value for slope computation.
+ *
+*/
+static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
+ void* slope_buffer, float max_bias) {
+ const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ // const float slope = (max_bias > 0.0f) ?
+ // h < n_head_log2 ?
+ // powf(m0, h + 1) :
+ // powf(m1, 2*(h - n_head_log2) + 1) :
+ // 1.0f;
+ // arange1
+ float start = 0 + 1;
+ float end = (n_head_log2 - 1) + 1;
+ float step = 1;
+ float count = n_head_log2;
+ // end needs to be +1 because aclnn uses a left-closed, right-open interval.
+ aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
+ if (n_head_log2 < n_head) {
+ // arange2
+ start = 2 * (n_head_log2 - n_head_log2) + 1;
+ end = 2 * ((n_head - 1) - n_head_log2) + 1;
+ step = 2;
+ count = n_head - n_head_log2;
+ aclnn_get_slope_inner(
+ ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
+ m1, count, start, end + 1, step);
+ }
+}
- // reshape mk
- int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
- size_t tmp_mk_nb[GGML_MAX_DIMS];
- tmp_mk_nb[0] = ggml_type_size(dst->type);
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+/**
+ * @brief Add ALiBi (Attention with Linear Biases) positional biases to the attention mask.
+ *
+ * This function computes the ALiBi slopes for each attention head (if max_bias > 0),
+ * multiplies them with the attention mask to produce bias tensors, and adds these biases
+ * to the destination tensor (@p dst).
+ *
+ * The function performs necessary broadcasting of the mask and slope tensors to match
+ * the shape of the destination tensor, then applies element-wise multiplication and addition
+ * using CANN operators.
+ *
+ * @param ctx CANN backend context for memory management and operator execution.
+ * @param mask Input attention mask tensor, assumed to be contiguous.
+ * @param dst Destination tensor to which ALiBi biases will be added.
+ * @param dst_ptr Pointer to the memory of the destination tensor.
+ * @param max_bias Maximum bias value controlling the slope scaling.
+ *
+ * @note
+ * - Write data into dst_ptr using only the shape information of the dst tensor.
+ * - `GGML_MAX_DIMS + 2` is used to extend tensor dimensions for broadcasting.
+ */
+static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
+ ggml_tensor* dst, void* dst_ptr, float max_bias) {
+ void* slope_buffer = nullptr;
+ void* bias_buffer = nullptr;
+
+ if (max_bias > 0.0f) {
+ int64_t n_heads = dst->ne[2];
+ ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+ slope_buffer = slope_allocator.get();
+ ggml_cann_pool_alloc bias_allocator(
+ ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
+ bias_buffer = bias_allocator.get();
+ aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
}
- aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
- tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
- ACL_FORMAT_ND);
- // acl_position * mk
- int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
- size_t tmp_output_nb[GGML_MAX_DIMS];
- tmp_output_nb[0] = ggml_type_size(dst->type);
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
+ // broadcast for mask, slop and dst;
+ int64_t nr2 = dst->ne[2] / mask->ne[2];
+ int64_t nr3 = dst->ne[3] / mask->ne[3];
+
+ // broadcast the mask across rows
+ int64_t mask_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], 1, mask->ne[3], 1 };
+ size_t mask_nb[] = {
+ mask_nb[0] = mask->nb[0], mask_nb[1] = mask->nb[1], mask_nb[2] = mask->nb[2],
+ mask_nb[3] = mask->nb[2], mask_nb[4] = mask->nb[3], mask_nb[5] = mask->nb[3]
+ };
+
+ int64_t dst_ne[] = { dst->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], nr3 };
+ size_t dst_nb[] = {
+ dst_nb[0] = dst->nb[0], dst_nb[1] = dst->nb[1], dst_nb[2] = dst->nb[2],
+ dst_nb[3] = dst->nb[2], dst_nb[4] = dst->nb[3], dst_nb[5] = dst->nb[3]
+ };
+
+ // slope is a 1 dim tensor, slope.ne2 == dst.ne2
+ int64_t slope_ne[] = { 1, 1, mask->ne[2], nr2, 1, 1 };
+ size_t slope_nb[GGML_MAX_DIMS + 2];
+ slope_nb[0] = sizeof(float);
+ for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
+ slope_nb[i] = slope_nb[i - 1] * slope_ne[i - 1];
}
- ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
- void* tmp_output_buffer = output_allocator.get();
- aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
- tmp_output_buffer, ggml_cann_type_mapping(dst->type),
- ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
- ACL_FORMAT_ND);
- aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
- // add
- aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
- ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
- tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
- tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
+ aclTensor* acl_slope = ggml_cann_create_tensor(
+ slope_buffer, ACL_FLOAT, sizeof(float),
+ slope_ne, slope_nb, GGML_MAX_DIMS + 2);
+ aclTensor* acl_mask = ggml_cann_create_tensor(
+ mask, mask_ne, mask_nb, GGML_MAX_DIMS + 2);
+
+ // write data into dst_ptr using only the shape information of the dst tensor.
+ aclTensor* acl_dst = ggml_cann_create_tensor(
+ dst_ptr, ggml_cann_type_mapping(dst->type),
+ ggml_type_size(dst->type), dst_ne, dst_nb,
+ GGML_MAX_DIMS + 2);
+
+ if (max_bias > 0.0f) {
+ int64_t bias_ne[] = { mask->ne[0], dst->ne[1], mask->ne[2], nr2, mask->ne[3], 1 };
+ size_t bias_nb[GGML_MAX_DIMS + 2];
+ bias_nb[0] = sizeof(float);
+ for (int i = 1; i < GGML_MAX_DIMS + 2; i++) {
+ bias_nb[i] = bias_nb[i - 1] * bias_ne[i - 1];
+ }
+ aclTensor* bias_tensor = ggml_cann_create_tensor(
+ bias_buffer, ACL_FLOAT, sizeof(float),
+ bias_ne, bias_nb, GGML_MAX_DIMS + 2);
+
+ aclnn_mul(ctx, acl_slope, acl_mask, bias_tensor);
+ aclnn_add(ctx, acl_dst, bias_tensor);
+ ggml_cann_release_resources(ctx, bias_tensor);
+ } else {
+ aclnn_add(ctx, acl_dst, acl_mask);
+ }
+ ggml_cann_release_resources(ctx, acl_slope, acl_mask, acl_dst);
}
-void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_dup(ctx, dst);
}
* @param acl_dst The destination tensor where the softmax results will be
* stored.
*/
-static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
- int64_t dim, aclTensor* acl_dst) {
+static void aclnn_softmax(ggml_backend_cann_context & ctx,
+ aclTensor* acl_src, int64_t dim, aclTensor * acl_dst) {
GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
}
-void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1]; // mask
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
- aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+ aclTensor* acl_dst = ggml_cann_create_tensor(dst);
- float scale = 1.0f;
+ float scale = 1.0f;
float max_bias = 0.0f;
- memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// input mul scale
aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+ ggml_cann_pool_alloc src_tensor_allocator(ctx.pool(), ggml_nbytes(src0));
+ void* src_tensor_buffer = src_tensor_allocator.get();
+ aclTensor* softmax_tensor = ggml_cann_create_tensor(
+ src_tensor_buffer, ggml_cann_type_mapping(src0->type),
+ ggml_element_size(src0), src0->ne, src0->nb,GGML_MAX_DIMS);
- size_t n_bytes = ggml_nbytes(src0);
- ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
- void* input_mul_scale_buffer = mul_scale_allocator.get();
- aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
- input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
- src0->nb, GGML_MAX_DIMS);
-
- bool inplace = false;
- aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
+ aclnn_muls(ctx, acl_src0, scale, softmax_tensor, false);
// mask
- aclTensor* acl_src1_fp32_tensor = nullptr;
- aclTensor* tmp_mask_tensor = nullptr;
- ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
if (src1) {
- const bool use_f16 = src1->type == GGML_TYPE_F16;
- if (use_f16) {
- // cast to fp32
- size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
- size_t src1_fp32_nb[GGML_MAX_DIMS];
- src1_fp32_nb[0] = sizeof(float_t);
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
- }
- src1_fp32_allocator.alloc(n_bytes);
- void* src1_fp32_buffer = src1_fp32_allocator.get();
- acl_src1_fp32_tensor = ggml_cann_create_tensor(
- src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
- src1_fp32_nb, GGML_MAX_DIMS);
- aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
- aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
- ggml_cann_release_resources(ctx, acl_src1);
- } else {
- acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
- }
-
- // broadcast the mask across rows, only use ne11 of ne01 in mask
- if (src1->ne[1] != src0->ne[1]) {
- // mask shape: [1,1,ne11,ne10]
- int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
- size_t tmp_mask_nb[GGML_MAX_DIMS];
- tmp_mask_nb[0] = sizeof(float_t);
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
- }
- tmp_mask_tensor = ggml_cann_create_tensor(
- src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
- GGML_MAX_DIMS, ACL_FORMAT_ND);
- }
-
- // alibi
- const int n_head = src0->ne[2];
- const size_t src_nb0 = src0->nb[0];
-
- n_bytes = ggml_nbytes(dst);
- ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
- void* output_buffer = output_allocator.get();
- aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
- output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
- dst->nb, GGML_MAX_DIMS);
- if (max_bias <= 0.0f) {
- // slope = 1.0
- if (tmp_mask_tensor) {
- aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
- alibi_output_tensor);
- } else {
- aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
- alibi_output_tensor);
- }
- } else {
- // slope != 1.0
- if (tmp_mask_tensor) {
- aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
- alibi_output_tensor, n_head, src0->ne, src_nb0,
- max_bias, dst);
- } else {
- aclnn_alibi(ctx, acl_input_mul_scale_tensor,
- acl_src1_fp32_tensor, alibi_output_tensor, n_head,
- src0->ne, src_nb0, max_bias, dst);
- }
- }
-
- // softmax
- aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
- ggml_cann_release_resources(ctx, alibi_output_tensor);
- } else {
- aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
+ aclnn_add_alibi(ctx, src1, src0, src_tensor_buffer, max_bias);
}
-
- ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
- acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
+ // softmax
+ aclnn_softmax(ctx, softmax_tensor, 3, acl_dst);
+ ggml_cann_release_resources(ctx, acl_src0, acl_dst, acl_scale, softmax_tensor);
}
/**
// Compute the slope if needed. Derived from ggml_cann_softmax().
if(maxBias != 0.0f){
// alibi
- const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
- const int64_t n_head = src0->ne[2];
- const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
- float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
- float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
- // init arange
- ggml_cann_pool_alloc arange_allocator(ctx.pool(),
- ne2_ne3 * faElemSize);
- void* tmp_arange_buffer = arange_allocator.get();
-
- // arange1: [1, ..., n_heads_log2_floor+1)
- float start = 1;
- float stop = n_heads_log2_floor + 1;
- float step = 1;
- int64_t n_elements_arange = n_heads_log2_floor;
-
- int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
- size_t tmp_arange1_nb[] = {faElemSize};
- aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
- tmp_arange_buffer, faDataType, faElemSize,
- tmp_arange1_ne, tmp_arange1_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-
- aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
-
- aclTensor* tmp_arange2_tensor = nullptr;
- if (n_heads_log2_floor < ne2_ne3) {
- // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
- start = 1;
- stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
- step = 2;
- n_elements_arange = ne2_ne3 - n_heads_log2_floor;
- int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
- size_t tmp_arange2_nb[] = {faElemSize};
-
- aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
- (char*)tmp_arange_buffer +
- n_heads_log2_floor * faElemSize,
- faDataType, faElemSize,
- tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
- n_elements_arange);
+ const int64_t n_heads = src0->ne[2];
+ ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+ void* slope_buffer = slope_allocator.get();
+ aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
+
+ int64_t slope_ne[] = {1, 1, n_heads, 1};
+ size_t slope_nb[GGML_MAX_DIMS];
+ slope_nb[0] = sizeof(float);
+ for(int i = 1;i<GGML_MAX_DIMS;i++) {
+ slope_nb[i] = slope_nb[i-1] * slope_ne[0];
}
- // init mk_base
- ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
- ne2_ne3 * faElemSize);
- void* tmp_mk_base_buffer = mk_base_allocator.get();
- int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
- size_t tmp_mk_base1_nb[] = {faElemSize};
- aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
- tmp_mk_base_buffer, faDataType, faElemSize,
- tmp_mk_base1_ne, tmp_mk_base1_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
-
- aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
-
- aclTensor* tmp_mk_base2_tensor = nullptr;
- if (n_heads_log2_floor < ne2_ne3) {
- int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
- size_t tmp_mk_base2_nb[] = {faElemSize};
- aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
- (char*)tmp_mk_base_buffer +
- n_heads_log2_floor * faElemSize,
- faDataType, faElemSize,
- tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
- }
+ aclTensor* slope_tensor = ggml_cann_create_tensor(
+ slope_buffer, ACL_FLOAT, sizeof(float),
+ slope_ne, slope_nb, GGML_MAX_DIMS);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
- // init mk
- int64_t tmp_mk_base_ne[] = {ne2_ne3};
- size_t tmp_mk_base_nb[] = {faElemSize};
- aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
- tmp_mk_base_buffer, faDataType, faElemSize,
- tmp_mk_base_ne, tmp_mk_base_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
- tmp_arange_buffer, faDataType, faElemSize,
- tmp_mk_base_ne, tmp_mk_base_nb,
- GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
- aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
-
- // reshape mk
- int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
- size_t tmp_mk_nb[GGML_MAX_DIMS];
- tmp_mk_nb[0] = faElemSize;
- for (int i = 1; i < GGML_MAX_DIMS; i++) {
- tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
- }
- aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
- tmp_mk_base_buffer, faDataType, faElemSize,
- tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
- ACL_FORMAT_ND);
- GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
-
- ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
- tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
- tmp_arange_tensor, tmp_mk_tensor);
+ ggml_cann_release_resources(ctx, slope_tensor);
}
}