// Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
// TODO: acl_yarn_ramp_tensor use rope cache.
bool yarn_ramp_tensor_updated = false;
- ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
acl_tensor_ptr acl_yarn_ramp_tensor;
if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
ctx.rope_cache.freq_scale != freq_scale)) {
yarn_ramp_tensor_updated = true;
-
+ if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
+ ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
+ }
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
- yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
- void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor =
- ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+ ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
+ } else {
+ acl_yarn_ramp_tensor =
+ ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
}
-
// Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) {
if (theta_scale_updated || yarn_ramp_tensor_updated) {
if (position_select_index_host) {
free(position_select_index_host);
}
+ if (yarn_ramp_cache) {
+ ACL_CHECK(aclrtFree(yarn_ramp_cache));
+ }
}
bool equal(int64_t theta_scale_length,
float * theta_scale_exp_host = nullptr;
int * position_select_index_host = nullptr;
void * position_select_index = nullptr;
+ void * yarn_ramp_cache = nullptr;
// sin/cos cache, used only to accelerate first layer on each device
void * sin_cache = nullptr;
void * cos_cache = nullptr;