]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CANN: ROPE cache sin/cos repeat (#15501)
authorChenguang Li <redacted>
Mon, 25 Aug 2025 02:32:21 +0000 (10:32 +0800)
committerGitHub <redacted>
Mon, 25 Aug 2025 02:32:21 +0000 (10:32 +0800)
Signed-off-by: noemotiovon <redacted>
ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/common.h

index 8f65904b8fe51ef7049af3784be2940f075c750f..bc33b99d96ea686ed1f42430cdfe7f4cd92a6a68 100755 (executable)
@@ -1257,12 +1257,20 @@ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
 
 void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+    if(acl_dst == nullptr) {
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCos, acl_src);
+    } else {
+        GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
+    }
 }
 
 void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+    if(acl_dst == nullptr) {
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSin, acl_src);
+    } else {
+        GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
+    }
 }
 
 void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -2221,13 +2229,54 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
     ggml_cann_release_resources(ctx, acl_index, acl_value);
 }
 
+/**
+ * @brief Initializes and caches sine/cosine positional encoding values
+ *        (used in RoPE, Rotary Position Embedding) for attention layers.
+ *
+ * This function computes and caches the sin/cos values of
+ * θ = position * theta_scale for RoPE encoding. The cache is shared
+ * across attention layers, and only the first attention layer will
+ * trigger initialization. The cache includes repeated sin/cos values
+ * with different repeat methods depending on the @param is_neox flag.
+ *
+ * Steps performed by this function:
+ *   1. Identify whether the target tensor belongs to Q/K in attention
+ *      and restrict computation to the first layer only.
+ *   2. Initialize the theta scale array (arange → power → freq scaling).
+ *   3. Allocate sin/cos caches if the max prompt length increases.
+ *   4. Compute θ = position * theta_scale.
+ *   5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
+ *   6. Expand sin/cos values by repeat or repeat_interleave depending
+ *      on whether @param is_neox is enabled.
+ *   7. Store the computed values into persistent buffers
+ *      (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
+ *
+ * @param ctx         The CANN backend context, holding memory pool,
+ *                    stream, and persistent buffers for rope init/cache.
+ * @param dst         The destination ggml_tensor whose computation
+ *                    depends on the cached RoPE values (usually Qcur/Kcur).
+ * @param theta_scale Scalar exponent base for computing theta scale values.
+ * @param freq_scale  Frequency scaling factor, applied to theta scale.
+ * @param attn_factor Attention scaling factor, applied to sin/cos.
+ * @param is_neox     Whether to use Neox-style repeat strategy
+ *                    (dim expansion vs repeat_interleave).
+ */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
-                             aclTensor* acl_cos_repeat_tensor,
-                             aclTensor* acl_sin_repeat_tensor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
+    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
+    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
+
+    // used for accuracy testing
+    bool is_attention = is_q || is_k;
+
+    // just compute in first layer in attention
+    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
+    if(is_attention && !is_fisrt_layer) {
+        return;
+    }
 
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
@@ -2253,21 +2302,16 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
 
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    if(ctx.init_ptr == nullptr || !is_attention) {
+    // init theta scale, just one time
+    if(ctx.rope_init_ptr == nullptr || !is_attention) {
         // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if(ctx.init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.init_ptr));
+        if(ctx.rope_init_ptr != nullptr){
+            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
 
         aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
+            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         float start = 0;
         float step = 1;
@@ -2297,67 +2341,55 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
     }
 
-    if(ctx.sin_ptr == nullptr) {
-        int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
-        ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-    }
+    // init sin_repeat && cos_repeat, one token just init in 0 layer
     if(position_length > ctx.max_prompt_length) {
         ctx.max_prompt_length = position_length;
-        int64_t theta_length = theta_scale_length * ctx.max_prompt_length;
-        ACL_CHECK(aclrtFree(ctx.sin_ptr));
-        ACL_CHECK(aclrtFree(ctx.cos_ptr));
-        ACL_CHECK(aclrtMalloc(&ctx.sin_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.cos_ptr, theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
+        if(ctx.rope_sin_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
+            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
+        }
+        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
     }
 
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-
-    if(is_fisrt_layer || !is_attention) {
-
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.init_ptr, ACL_FLOAT, sizeof(float_t),
+    aclTensor* acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
 
-        // position
-        aclTensor* acl_position_tensor = ggml_cann_create_tensor(
-            src1->data, ggml_cann_type_mapping(src1->type),
-            ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
-
-        // power * position
-        int64_t theta_length = theta_scale_length * position_length;
-        ggml_cann_pool_alloc theta_allocator(ctx.pool(),
-                                            theta_length * sizeof(float_t));
-        void* theta_buffer = theta_allocator.get();
-
-        aclTensor* acl_theta_tensor =
-            ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
-                                    theta_ne, theta_nb, GGML_MAX_DIMS);
-        aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
-                acl_theta_tensor);
-
-        // sin/cos
-        aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
-            ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-            GGML_MAX_DIMS, ACL_FORMAT_ND);
-        aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
-
-        aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
-            ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-            GGML_MAX_DIMS, ACL_FORMAT_ND);
-        aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
-
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
-            acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
-    }
-
+    // position
+    aclTensor* acl_position_tensor = ggml_cann_create_tensor(
+        src1->data, ggml_cann_type_mapping(src1->type),
+        ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
+
+    // power * position
+    int64_t theta_length = theta_scale_length * position_length;
+    ggml_cann_pool_alloc theta_allocator(ctx.pool(),
+                                        theta_length * sizeof(float_t));
+    void* theta_buffer = theta_allocator.get();
+
+    aclTensor* acl_theta_tensor =
+        ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
+                                theta_ne, theta_nb, GGML_MAX_DIMS);
+    aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
+            acl_theta_tensor);
+
+    // sin/cos
+    ggml_cann_pool_alloc sin_allocator(ctx.pool(),
+                                    theta_length * sizeof(float_t));
+    void* sin_buffer = sin_allocator.get();
     aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
-            ctx.sin_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-            GGML_MAX_DIMS, ACL_FORMAT_ND);
+        sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
+
+    ggml_cann_pool_alloc cos_allocator(ctx.pool(),
+                                    theta_length * sizeof(float_t));
+    void* cos_buffer = cos_allocator.get();
     aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
-            ctx.cos_ptr, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
-            GGML_MAX_DIMS, ACL_FORMAT_ND);
+        cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
 
     // attn_factor
     if (attn_factor != 1) {
@@ -2365,6 +2397,19 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
 
+    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    size_t sin_reshape_nb[GGML_MAX_DIMS];
+    sin_reshape_nb[0] = sizeof(float_t);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
+    }
+    aclTensor* acl_sin_repeat_tensor =
+        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+    aclTensor* acl_cos_repeat_tensor =
+        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+
     // repeat
     if (is_neox) {
         int64_t repeatsArray[] = {1, 1, 1, 2};
@@ -2380,8 +2425,9 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                                 num_repeats, output_size);
     }
 
-    // release
-    ggml_cann_release_resources(ctx, acl_sin_tensor, acl_cos_tensor);
+    ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
+        acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
+        acl_cos_repeat_tensor);
 }
 
 #ifdef __cplusplus
@@ -2435,13 +2481,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
-    // init cos/sin cache
-    ggml_cann_pool_alloc sin_allocator(
-        ctx.pool(), ne00 * ne02 * sizeof(float_t));
-    ggml_cann_pool_alloc cos_allocator(
-        ctx.pool(), ne00 * ne02 * sizeof(float_t));
-    void* sin_buffer = sin_allocator.get();
-    void* cos_buffer = cos_allocator.get();
+    // init ctx.rope_cos/rope_sin cache
+    aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2450,13 +2491,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
-    aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
-                    theta_scale, freq_scale, attn_factor, is_neox);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
index 5858bd3f6a1970bf5bf8f5efb8a79e3e7757a4eb..33794062f565d5b7473f946a4206d6002c407463 100755 (executable)
@@ -368,10 +368,6 @@ struct ggml_backend_cann_context {
     std::string name;                /**< Name of the device. */
     std::string description;         /**< Description of the device. */
     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
-    void* init_ptr = nullptr;
-    void* sin_ptr = nullptr;
-    void* cos_ptr = nullptr;
-    int64_t max_prompt_length = 65536;
 #ifdef USE_ACL_GRAPH
     /// Cached CANN ACL graph used for executing the current ggml computation graph.
     std::unique_ptr<ggml_cann_graph> cann_graph;
@@ -379,6 +375,12 @@ struct ggml_backend_cann_context {
     cann_task_queue task_queue;
     bool async_mode;
     bool support_set_rows;
+    // Rope Cache
+    void* rope_init_ptr = nullptr;
+    void* rope_sin_ptr = nullptr;
+    void* rope_cos_ptr = nullptr;
+    int64_t max_prompt_length = 0;
+    // Constant Pool
     void* f32_zero_cache = nullptr;
     void* f32_one_cache = nullptr;
     int64_t f32_zero_cache_element = 0;
@@ -422,14 +424,20 @@ struct ggml_backend_cann_context {
                 ACL_CHECK(aclrtDestroyStream(streams[i]));
             }
         }
-        if(init_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(init_ptr));
+        if(rope_init_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(rope_init_ptr));
         }
-        if(sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(sin_ptr));
+        if(rope_sin_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(rope_sin_ptr));
         }
-        if(cos_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(cos_ptr));
+        if(rope_cos_ptr != nullptr) {
+            ACL_CHECK(aclrtFree(rope_cos_ptr));
+        }
+        if(f32_zero_cache != nullptr) {
+            ACL_CHECK(aclrtFree(f32_zero_cache));
+        }
+        if(f32_one_cache != nullptr) {
+            ACL_CHECK(aclrtFree(f32_one_cache));
         }
     }