#include <aclnnop/aclnn_gt_scalar.h>
#include <aclnnop/aclnn_pow.h>
#include <aclnnop/aclnn_grouped_matmul_v2.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <float.h>
#include <cmath>
#include <vector>
#include "ggml-impl.h"
+#include "ggml.h"
#define GGML_COMMON_DECL_C
#include "../ggml-common.h"
+
void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
aclTensor ** acl_src1, aclTensor ** acl_dst) {
GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
break;
}
}
+
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+ ggml_tensor* src0 = dst->src[0]; // q, fp32
+ ggml_tensor* src1 = dst->src[1]; // k, fp16
+ ggml_tensor* src2 = dst->src[2]; // v, fp16
+ ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
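+    // op_params of GGML_OP_FLASH_ATTN_EXT carries scale, max_bias (ALiBi) and
+    // logit_softcap as consecutive floats.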
+ float maxBias = 0.0f;
+ float scaleValue = 1.0f;
+ float logitSoftcap = 0.0f;
+ memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float));
+ memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float));
+ memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
+
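+    // The fused-kernel path below only covers logit_softcap == 0.0f; a non-zero
+    // softcap would require an extra tanh on the attention scores, which this
+    // call does not express, so that case aborts at the end of the function.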
+ if(logitSoftcap == 0.0f){
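+        // FusedInferAttentionScore works on half-precision inputs, so Q and the
+        // intermediate output are staged as FP16; K, V and the mask already
+        // arrive as FP16 from llama.cpp.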
+ size_t faElemSize = sizeof(uint16_t);
+ auto faDataType = ACL_FLOAT16; //ACL_BF16;
+
+ aclTensor* acl_src0_f16_tensor = nullptr;
+ aclTensor* acl_src1_f16_tensor = nullptr;
+ aclTensor* acl_src2_f16_tensor = nullptr;
+ aclTensor* acl_dst_f16_tensor = nullptr;
+
+ // Step 1: cast the src0 (Query) to fp16 if needed
+ ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+ void* src0_f16_buffer = nullptr;
+
+ if(ggml_cann_type_mapping(src0->type) != faDataType){
+ aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+ src0_f16_buffer = src0_f16_allocator.alloc(
+ ggml_nelements(src0) * faElemSize);
+
+ int64_t* src0_f16_ne = src0->ne;
+ size_t src0_f16_nb[GGML_MAX_DIMS];
+ src0_f16_nb[0] = sizeof(uint16_t);
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+ }
+
+ acl_src0_f16_tensor = ggml_cann_create_tensor(
+ src0_f16_buffer, faDataType, faElemSize,
+ src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+ );
+ aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
+ ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+ }else{
+ acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+ }
+
+ // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+ // and the direct output from FusedInferAttention
+
+ acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
+ acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+
+ ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+ void* out_f16_buffer = out_f16_allocator.alloc(
+ ggml_nelements(dst) * faElemSize);
+
+ int64_t* out_f16_ne = src0->ne;
+ size_t out_f16_nb[GGML_MAX_DIMS];
+ out_f16_nb[0] = faElemSize;
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+ }
+
+ acl_dst_f16_tensor = ggml_cann_create_tensor(
+ out_f16_buffer, faDataType, faElemSize,
+ out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+ );
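+        // FusedInferAttention writes attentionOut with the same shape as Q
+        // (BNSD); Step 6 swaps the S and N dims back into dst's layout.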
+
+        // Step 3: create the pseShift tensor if needed.
+        // This tensor corresponds to the (f16) mask in llama.cpp.
+
+ aclTensor* bcast_pse_tensor = nullptr;
+ int64_t bcast_pse_ne[GGML_MAX_DIMS];
+ size_t bcast_pse_nb[GGML_MAX_DIMS];
+ ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
+ void* bcast_pse_buffer = nullptr;
+
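+        // llama.cpp supplies one mask shared by all heads, while pseShift is
+        // built here with an explicit head dimension, so the mask is repeated
+        // along dim 2 (the head dim) into bcast_pse_tensor.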
+ if(src3 != nullptr){
+ bcast_pse_buffer = bcast_pse_allocator.alloc(
+ ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
+
+ if(src0->ne[1] > 1){
+                // Case 1: prefill (more than one query row): broadcast the
+                // mask across the head dimension.
+ aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+ bcast_pse_ne[0] = src3->ne[0];
+ bcast_pse_ne[1] = src3->ne[1];
+ bcast_pse_ne[2] = src0->ne[2];
+ bcast_pse_ne[3] = src3->ne[3];
+
+ bcast_pse_nb[0] = sizeof(uint16_t);
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+ }
+
+ bcast_pse_tensor = ggml_cann_create_tensor(
+ bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+ int64_t repeats[] = {1, src0->ne[2], 1, 1};
+ aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
+
+ ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+ }else{
+                // Case 2: decode: keep only the first src0->ne[1] row(s) of the
+                // mask, then broadcast them across the head dimension.
+ int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
+ size_t* trunc_pse_nb = src3->nb;
+
+ aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+ src3->data, ACL_FLOAT16, sizeof(uint16_t),
+ trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
+
+ bcast_pse_ne[0] = src3->ne[0];
+ bcast_pse_ne[1] = src0->ne[1];
+ bcast_pse_ne[2] = src0->ne[2];
+ bcast_pse_ne[3] = src3->ne[3];
+
+ bcast_pse_nb[0] = sizeof(uint16_t);
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
+ }
+
+ bcast_pse_tensor = ggml_cann_create_tensor(
+ bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
+ bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+
+ int64_t repeats[] = {1, src0->ne[2], 1, 1};
+ aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
+
+ ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+ }
+
+            // Compute the ALiBi slopes if needed (adapted from ggml_cann_softmax()).
+ if(maxBias != 0.0f){
+ // alibi
+ const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+ const int64_t n_head = src0->ne[2];
+ const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+ float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+ float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
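+                // ALiBi slopes (same scheme as ggml_cann_softmax()):
+                //   head h <  n_heads_log2_floor: slope = m0^(h+1)
+                //   head h >= n_heads_log2_floor: slope = m1^(2*(h - n_heads_log2_floor) + 1)
+                // computed below as mk_base^arange via aclnn_pow_tensor_tensor.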
+ // init arange
+ ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+ ne2_ne3 * faElemSize);
+ void* tmp_arange_buffer = arange_allocator.get();
+
+ // arange1: [1, ..., n_heads_log2_floor+1)
+ float start = 1;
+ float stop = n_heads_log2_floor + 1;
+ float step = 1;
+ int64_t n_elements_arange = n_heads_log2_floor;
+
+ int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+ size_t tmp_arange1_nb[] = {faElemSize};
+ aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+ tmp_arange_buffer, faDataType, faElemSize,
+ tmp_arange1_ne, tmp_arange1_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+ aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+ aclTensor* tmp_arange2_tensor = nullptr;
+ if (n_heads_log2_floor < ne2_ne3) {
+                    // arange2: 1, 3, 5, ..., 2 * (ne2_ne3 - n_heads_log2_floor) - 1
+ start = 1;
+ stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+ step = 2;
+ n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+ int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+ size_t tmp_arange2_nb[] = {faElemSize};
+
+                    // assign to the handle declared above; do not shadow it
+                    tmp_arange2_tensor = ggml_cann_create_tensor(
+ (char*)tmp_arange_buffer +
+ n_heads_log2_floor * faElemSize,
+ faDataType, faElemSize,
+ tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+ n_elements_arange);
+ }
+
+ // init mk_base
+ ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+ ne2_ne3 * faElemSize);
+ void* tmp_mk_base_buffer = mk_base_allocator.get();
+ int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+ size_t tmp_mk_base1_nb[] = {faElemSize};
+ aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, faDataType, faElemSize,
+ tmp_mk_base1_ne, tmp_mk_base1_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+ aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+ aclTensor* tmp_mk_base2_tensor = nullptr;
+ if (n_heads_log2_floor < ne2_ne3) {
+ int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+ size_t tmp_mk_base2_nb[] = {faElemSize};
+                    // assign to the handle declared above; do not shadow it
+                    tmp_mk_base2_tensor = ggml_cann_create_tensor(
+ (char*)tmp_mk_base_buffer +
+ n_heads_log2_floor * faElemSize,
+ faDataType, faElemSize,
+ tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+ }
+
+ // init mk
+ int64_t tmp_mk_base_ne[] = {ne2_ne3};
+ size_t tmp_mk_base_nb[] = {faElemSize};
+ aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, faDataType, faElemSize,
+ tmp_mk_base_ne, tmp_mk_base_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+ tmp_arange_buffer, faDataType, faElemSize,
+ tmp_mk_base_ne, tmp_mk_base_nb,
+ GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+ aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+ // reshape mk
+ int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+ size_t tmp_mk_nb[GGML_MAX_DIMS];
+ tmp_mk_nb[0] = faElemSize;
+ for (int i = 1; i < GGML_MAX_DIMS; i++) {
+ tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+ }
+ aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+ tmp_mk_base_buffer, faDataType, faElemSize,
+ tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+ ACL_FORMAT_ND);
+ GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+ ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+ tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+ tmp_arange_tensor, tmp_mk_tensor);
+ }
+ }
+
+ // Step 4: set the inputs for FusedInferAttention.
+ int kvTensorNum = 1;
+ aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+ aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
+ aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
+ auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+ auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+
+ int64_t numHeads = src0->ne[2]; // N
+ int64_t numKeyValueHeads = src1->ne[2];
+ // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+ int64_t preTokens = 65535;
+ int64_t nextTokens = 65535;
+ char layout[5] = {'B', 'N', 'S', 'D', 0};
+ int64_t sparseMode = 0;
+ int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
+ int64_t blockSize = 0;
+ int64_t antiquantMode = 0;
+ bool softmaxLseFlag = false;
+ int64_t keyAntiquantMode = 0;
+ int64_t valueAntiquantMode = 0;
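+        // inputLayout "BNSD" = (batch, num_heads, seq_len, head_dim).
+        // preTokens/nextTokens are left large so no extra attention window is
+        // imposed here; any masking comes from the pseShift tensor built in Step 3.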
+
+ // Step 5: launch the FusedInferAttentionScoreV2 kernel.
+ // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+ GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+ acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+ bcast_pse_tensor, nullptr, // pse, mask
+ nullptr, nullptr, // actSeqLen, actSeqLenkv
+ nullptr, nullptr, // deqScale1, quantScale1
+ nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2
+ nullptr, nullptr, // antiquantScale, antiquantOffset
+ nullptr, // blockTable
+ nullptr, nullptr, // qPadSize, kvPadSize
+ nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset
+ nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset
+ nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen
+ numHeads, scaleValue, // heads, scaleValue
+ preTokens, nextTokens, // preTokens, nextTokens
+ layout, // inputLayout
+ numKeyValueHeads, // numKVHeads
+ sparseMode, innerPrecise, // sparseMode, innerPrecise
+ blockSize, antiquantMode, // blockSize, antiquantMode
+ softmaxLseFlag, // softmaxLseFlag
+ keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
+ acl_dst_f16_tensor, // attentionOut
+ nullptr // softmaxLse
+ );
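+        // softmaxLseFlag is false, so only attentionOut is produced and the
+        // softmaxLse output slot stays nullptr.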
+
+        // Step 6: post-processing: permute back to dst's layout and cast to dst's type if needed.
+
+ int64_t new_dim[] = {0, 2, 1, 3};
+ aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+ if(ggml_cann_type_mapping(dst->type) != faDataType){
+ ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+ perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+ void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+ int64_t* perm_out_f16_ne = dst->ne;
+ size_t perm_out_f16_nb[GGML_MAX_DIMS];
+ perm_out_f16_nb[0] = faElemSize;
+ for(int i = 1; i < GGML_MAX_DIMS; ++i){
+ perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+ }
+ aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+ perm_out_f16_buffer, faDataType, faElemSize,
+ perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+ aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+ aclnn_cast(ctx,
+ acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+ ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+ }else{
+ // only need to permute
+ aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+ }
+ ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
+ acl_src1_f16_tensor,
+ acl_src2_f16_tensor,
+ acl_dst_f16_tensor,
+ acl_dst_tensor);
+ if(src3 != nullptr){
+ ggml_cann_release_resources(ctx, bcast_pse_tensor);
+ }
+ }else{
+        GGML_ABORT("logit_softcap != 0.0f is not supported by the CANN flash_attn_ext path.");
+ }
+}