]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: Add ROPE sin/cos cache for reuse (llama/15912)
authorChenguang Li <redacted>
Wed, 10 Sep 2025 10:42:00 +0000 (18:42 +0800)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* CANN: Add ROPE sin/cos cache for reuse

Introduce sin/cos caching mechanism in ROPE to avoid redundant
computation across layers. The cache is built on the first layer
per device and reused by subsequent layers if parameters match.

- Added sin_cache / cos_cache pointers and position_length tracking
- Introduced cache validity flags and properties:
  (ext_factor, theta_scale, freq_scale, attn_factor, is_neox)
- Accelerates ROPE by eliminating repeated sin/cos generation

This change reduces overhead in multi-layer scenarios while
preserving correctness by verifying parameter consistency.

Co-authored-by: hipudding <redacted>
* fix typo

Signed-off-by: noemotiovon <redacted>
---------

Signed-off-by: noemotiovon <redacted>
Co-authored-by: hipudding <redacted>
src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/common.h
src/ggml-cann/ggml-cann.cpp

index ac2e2e1adf3bb8175f58baa51a945956013bbc8a..434023dd22ab3f9062de06935f9c1b1500c4a52e 100755 (executable)
@@ -2268,8 +2268,6 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  *                           stream, and persistent buffers for rope init/cache.
  * @param dst                The destination ggml_tensor whose computation
  *                           depends on the RoPE values (usually Qcur/Kcur).
- * @param sin_tensor_buffer  Pre-allocated buffer for storing repeated sin values.
- * @param cos_tensor_buffer  Pre-allocated buffer for storing repeated cos values.
  * @param theta_scale        Scalar exponent base for computing theta scale values.
  * @param freq_scale         Frequency scaling factor, applied to theta scale.
  * @param attn_factor        Attention scaling factor, applied to sin/cos.
@@ -2277,17 +2275,23 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  *                           (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
-                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float* corr_dims, float ext_factor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
-    // int sin/cos cache, cache has different repeat method depond on
-    // @param.is_neox
-
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
+    if(src2 == nullptr && ctx.rope_cache.cached
+        && ctx.rope_cache.ext_factor == ext_factor
+        && ctx.rope_cache.theta_scale == theta_scale
+        && ctx.rope_cache.freq_scale == freq_scale
+        && ctx.rope_cache.attn_factor == attn_factor
+        && ctx.rope_cache.is_neox == is_neox) {
+        // use cache.
+        return;
+    }
+
     int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
@@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ctx.rope_cache.freq_scale != freq_scale) {
 
         ctx.rope_cache.theta_scale_length = theta_scale_length;
-        ctx.rope_cache.theta_scale = theta_scale;
-        ctx.rope_cache.freq_scale = freq_scale;
 
         if (ctx.rope_cache.theta_scale_cache != nullptr) {
             ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
             // return MIN(1, MAX(0, y)) - 1;
             yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
             void* yarn_ramp_buffer = yarn_ramp_allocator.get();
-            acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
+            acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
                                            theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
             float zero_value = 0, one_value = 1;
             float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
@@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }
 
+    // init sin_repeat && cos_repeat, only to accelerate first layer on each device
+    if (position_length > ctx.rope_cache.position_length) {
+        ctx.rope_cache.position_length = position_length;
+        if (ctx.rope_cache.sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
+        }
+        if (ctx.rope_cache.cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
+        }
+        int64_t repeat_theta_length = theta_scale_length * position_length * 2;
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+    }
+
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
@@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                                 num_repeats, output_size);
     }
 
+    // Other layers use cache except first layer.
+    ctx.rope_cache.cached = true;
+    ctx.rope_cache.ext_factor = ext_factor;
+    ctx.rope_cache.theta_scale = theta_scale;
+    ctx.rope_cache.freq_scale = freq_scale;
+    ctx.rope_cache.attn_factor = attn_factor;
+    ctx.rope_cache.is_neox = is_neox;
+
     ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
         acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
         acl_cos_repeat_tensor);
@@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
 #endif
 
 void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    // TODO: use ascendc
-    // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
-    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
-    // sin/cos tensor length.
-    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
-    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    void *sin_tensor_buffer = sin_tensor_allocator.get();
-    void *cos_tensor_buffer = cos_tensor_allocator.get();
-
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
+    aclnn_cache_init(ctx, dst, corr_dims, ext_factor,
                     theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
@@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
index 17d7dbc75c244b8ed4c579514101420ad45ada84..c5fce8dc91f51c0316861c0a775b5ea50ff0b2b4 100755 (executable)
@@ -425,12 +425,27 @@ struct ggml_cann_rope_cache {
         if(theta_scale_cache != nullptr) {
             ACL_CHECK(aclrtFree(theta_scale_cache));
         }
+        if(sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(sin_cache));
+        }
+        if(cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(cos_cache));
+        }
     }
 
     void* theta_scale_cache = nullptr;
     int64_t theta_scale_length = 0;
+    // sin/cos cache, used only to accelerate first layer on each device
+    void* sin_cache = nullptr;
+    void* cos_cache = nullptr;
+    int64_t position_length = 0;
+    // Properties to check before reusing the sincos cache
+    bool cached = false;
+    float ext_factor = 0.0f;
     float theta_scale = 0.0f;
     float freq_scale = 0.0f;
+    float attn_factor = 0.0f;
+    bool is_neox = false;
 };
 
 struct ggml_cann_tensor_cache {
index aa5913a3776df5d31fb1c138a68d065da2090c3d..d148174f1e84f253b8b3625493110789ca88eeb5 100755 (executable)
@@ -2353,6 +2353,9 @@ static enum ggml_status ggml_backend_cann_graph_compute(
     ggml_cann_set_device(cann_ctx->device);
     g_nz_workspaces[cann_ctx->device].clear();
 
+    // calculate rope cache for fist layer in current device.
+    cann_ctx->rope_cache.cached = false;
+
 #ifdef USE_ACL_GRAPH
     bool use_cann_graph = true;
     bool cann_graph_update_required = false;