]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CANN: fix RoPE cache issue on multi-device (#15629)
authorhipudding <redacted>
Mon, 1 Sep 2025 00:57:00 +0000 (08:57 +0800)
committerGitHub <redacted>
Mon, 1 Sep 2025 00:57:00 +0000 (08:57 +0800)
* CANN: fix RoPE cache issue on multi-device

RoPE cache only needs to be computed once per token.
However, in multi-device scenarios, not every device starts
computation from layer 0, which may lead to unallocated memory
issues and precision errors.

This commit records the first layer of each device to avoid
the above issues.

* CANN: Optimize first-layer detection method

* CANN: Remove trailing whitespace

* CANN: Only cache the data that can be determined as unchanged through the parameters.

* CANN: Update function comment

ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/common.h
ggml/src/ggml-cann/ggml-cann.cpp

index c42871c575822817c9d18fd2314352525e64347f..1f1d489ffa4113ec622a24b2eaf849fee7da250c 100755 (executable)
@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_gamma = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
+        &ctx.rms_norm_one_tensor_cache.cache,
+        ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
         1,        // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
+        &ctx.rms_norm_zero_tensor_cache.cache,
+        ctx.rms_norm_zero_tensor_cache.size,
         src->ne,
         acl_rstd_nb,
         GGML_MAX_DIMS,
@@ -2248,43 +2248,31 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  *   5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
  *   6. Expand sin/cos values by repeat or repeat_interleave depending
  *      on whether @param is_neox is enabled.
- *   7. Store the computed values into persistent buffers
- *      (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
- *
- * @param ctx         The CANN backend context, holding memory pool,
- *                    stream, and persistent buffers for rope init/cache.
- * @param dst         The destination ggml_tensor whose computation
- *                    depends on the cached RoPE values (usually Qcur/Kcur).
- * @param theta_scale Scalar exponent base for computing theta scale values.
- * @param freq_scale  Frequency scaling factor, applied to theta scale.
- * @param attn_factor Attention scaling factor, applied to sin/cos.
- * @param is_neox     Whether to use Neox-style repeat strategy
- *                    (dim expansion vs repeat_interleave).
+ *
+ * @param ctx                The CANN backend context, holding memory pool,
+ *                           stream, and persistent buffers for rope init/cache.
+ * @param dst                The destination ggml_tensor whose computation
+ *                           depends on the RoPE values (usually Qcur/Kcur).
+ * @param sin_tensor_buffer  Pre-allocated buffer for storing repeated sin values.
+ * @param cos_tensor_buffer  Pre-allocated buffer for storing repeated cos values.
+ * @param theta_scale        Scalar exponent base for computing theta scale values.
+ * @param freq_scale         Frequency scaling factor, applied to theta scale.
+ * @param attn_factor        Attention scaling factor, applied to sin/cos.
+ * @param is_neox            Whether to use Neox-style repeat strategy
+ *                           (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    // just compute in first layer in attention
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-    if(is_attention && !is_fisrt_layer) {
-        return;
-    }
 
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    int64_t theta_scale_length = ne00 / 2;
+    int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                           theta_scale_length * sizeof(float_t)};
@@ -2302,21 +2290,32 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
 
-    // init theta scale, just one time
-    if(ctx.rope_init_ptr == nullptr || !is_attention) {
-        // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if(ctx.rope_init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
+    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    aclTensor* acl_theta_scale_tensor = nullptr;
+    // cache theta scale
+    if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
+        // theta_scale and freq_scale should not change during the current token inference process,
+        // so we can directly use == here instead of comparing the absolute difference.
+        ctx.rope_cache.theta_scale != theta_scale ||
+        ctx.rope_cache.freq_scale != freq_scale) {
+
+        ctx.rope_cache.theta_scale_length = theta_scale_length;
+        ctx.rope_cache.theta_scale = theta_scale;
+        ctx.rope_cache.freq_scale = freq_scale;
+
+        if (ctx.rope_cache.theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
 
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+
         float start = 0;
         float step = 1;
-        float stop = ne00 / 2;
-        float n_elements = ne00 / 2;
+        float stop = theta_scale_length;
+        float n_elements = theta_scale_length;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
 
         // power
@@ -2328,35 +2327,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
-
-        // freq_factors
-        if (src2) {
-            aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
-                src2->data, ggml_cann_type_mapping(src2->type),
-                ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-            aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-            ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
-        }
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
+        ggml_cann_release_resources(ctx, acl_theta_scale);
+    } else {
+        // use cache
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
     }
 
-    // init sin_repeat && cos_repeat, one token just init in 0 layer
-    if(position_length > ctx.max_prompt_length) {
-        ctx.max_prompt_length = position_length;
-        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
-        if(ctx.rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
-            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
-        }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+    ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
+    // freq_factors
+    if (src2) {
+        freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
+        void* freq_fac_res_ptr = freq_fac_res_allocator.get();
+        aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+            src2->data, ggml_cann_type_mapping(src2->type),
+            ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
+            freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
+            theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
+        std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }
 
-    aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
@@ -2397,17 +2391,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
 
-    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float_t);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2449,6 +2443,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
+    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2481,8 +2476,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
+    // sin/cos tensor length.
+    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
+    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    void *sin_tensor_buffer = sin_tensor_allocator.get();
+    void *cos_tensor_buffer = cos_tensor_allocator.get();
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
+    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
+                    theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2491,10 +2494,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
index 88cc3f481ed3832c191d1a5e8f156677d23669f7..f71aa9d1de65c1e5332dbeec2357373decc6c14a 100755 (executable)
@@ -360,6 +360,30 @@ struct ggml_cann_graph {
 };
 #endif  // USE_ACL_GRAPH
 
+struct ggml_cann_rope_cache {
+    ~ggml_cann_rope_cache() {
+        if(theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(theta_scale_cache));
+        }
+    }
+
+    void* theta_scale_cache = nullptr;
+    int64_t theta_scale_length = 0;
+    float theta_scale = 0.0f;
+    float freq_scale = 0.0f;
+};
+
+struct ggml_cann_tensor_cache {
+    ~ggml_cann_tensor_cache() {
+        if(cache != nullptr) {
+            ACL_CHECK(aclrtFree(cache));
+        }
+    }
+
+    void* cache = nullptr;
+    int64_t size = 0;
+};
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -375,15 +399,11 @@ struct ggml_backend_cann_context {
     cann_task_queue task_queue;
     bool async_mode;
     // Rope Cache
-    void* rope_init_ptr = nullptr;
-    void* rope_sin_ptr = nullptr;
-    void* rope_cos_ptr = nullptr;
-    int64_t max_prompt_length = 0;
+    ggml_cann_rope_cache rope_cache;
     // Constant Pool
-    void* f32_zero_cache = nullptr;
-    void* f32_one_cache = nullptr;
-    int64_t f32_zero_cache_element = 0;
-    int64_t f32_one_cache_element = 0;
+    ggml_cann_tensor_cache rms_norm_one_tensor_cache;
+    ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
+
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
@@ -415,21 +435,6 @@ struct ggml_backend_cann_context {
                 ACL_CHECK(aclrtDestroyStream(streams[i]));
             }
         }
-        if(rope_init_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_init_ptr));
-        }
-        if(rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_sin_ptr));
-        }
-        if(rope_cos_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(rope_cos_ptr));
-        }
-        if(f32_zero_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_zero_cache));
-        }
-        if(f32_one_cache != nullptr) {
-            ACL_CHECK(aclrtFree(f32_one_cache));
-        }
     }
 
     /**
index 7b3aca9db97578af5ac0c00f6d6bfc88ddeba223..15ea85e2737ed1085f172e9f0dd77fb9eb15862d 100755 (executable)
@@ -2247,6 +2247,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
         (ggml_backend_cann_context*)backend->context;
     ggml_cann_set_device(cann_ctx->device);
     release_nz_workspace();
+
 #ifdef USE_ACL_GRAPH
     bool use_cann_graph = true;
     bool cann_graph_update_required = false;