]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: add RoPE cache preload before ACL graph capture (llama/20747)
authorChenguang Li <redacted>
Mon, 23 Mar 2026 07:24:06 +0000 (15:24 +0800)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
ACL graph capture disallows host-to-device memcpy and device memory
malloc/free on the captured stream. Pre-load the RoPE cache before
capture so that:
- Host-to-device copies and allocations run on the non-captured stream
- Cache metadata is populated and memory pool is warmed up
- During capture, only on-device computations are recorded; host-side
  and allocation branches are skipped

src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/aclnn_ops.h
src/ggml-cann/common.h
src/ggml-cann/ggml-cann.cpp

index b45774dde343a84719de2d5244e5ad0923f3e88c..adb4d68e8687e7fe1b3a82e16221f003e25ce98e 100644 (file)
@@ -3011,6 +3011,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     }
 }
 
+void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src0 = dst->src[0];
+
+    float     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int       sections[4];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
+
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    bool       is_neox    = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_imrope  = mode == GGML_ROPE_TYPE_IMROPE;
+    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision  = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_imrope || mrope_used) {
+        is_neox = true;
+    }
+
+    int64_t rope_dims = n_dims;
+    if (is_vision) {
+        rope_dims = src0->ne[0];
+    }
+
+    // Run the full cache init on the non-captured stream.  This performs all
+    // host-to-device memcpy, aclrtMalloc/Free, and on-device computations
+    // so that the memory pool is warmed up and cache metadata is populated.
+    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
+                          mrope_used, is_imrope, is_vision, rope_dims);
+
+    // Reset `cached` so that during graph capture the on-device computations
+    // (sin/cos, position multiply, repeat, etc.) still execute and get recorded
+    // into the captured graph.  The cache metadata (theta_scale_length,
+    // theta_scale, sections, position_length, etc.) remains set, which causes
+    // all host-to-device copy and malloc/free branches to be skipped.
+    ctx.rope_cache.cached = false;
+}
+
 void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src0 = dst->src[0];
 
index 3effa1c289c0f07048fa5871f329f91ddf3d1bc6..7f5ba4d3302d1a4e8a8ad239330f49c8727e363b 100644 (file)
@@ -543,6 +543,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  */
 void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Pre-load the RoPE cache before ACL graph capture.
+ *
+ * This function must be called outside of graph capture to perform
+ * host-to-device memory copies and device memory allocations that are
+ * not allowed on a captured stream.  After pre-loading, the rope cache
+ * metadata is updated so that the subsequent call to
+ * aclnn_rope_cache_init (inside graph capture) skips these operations
+ * and only records the on-device computations into the captured graph.
+ *
+ * @param ctx  CANN backend context.
+ * @param dst  A ROPE destination tensor from the computation graph.
+ */
+void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+
 /**
  * @brief   Computes the index of the maximum value along the specified dimension
  *          of a ggml tensor using the CANN backend.
index 0120f0dfd1e6bc125701229af6f451ef0739e64c..5f960548cd2e2f5b47bb5548b5bf816e2a93c296 100644 (file)
@@ -277,7 +277,7 @@ struct ggml_graph_node_properties {
             }
         }
 
-        if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU{
+        if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){
             return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
         }
         return true;
index 2f9c350789c56a2faa88f583fc340a64c04a73cb..6f26e91e046525cf303f7bd30bf0e2a920fce85f 100644 (file)
@@ -2225,6 +2225,19 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
             // If no matching graph is found, add a new ACL graph.
             ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
             cann_ctx->graph_lru_cache.push(new_graph);
+
+            // Pre-load rope cache before graph capture.  During capture the
+            // stream cannot perform host-to-device memcpy or device memory
+            // malloc/free.  Running the full cache init now populates the
+            // cache metadata so these branches are skipped during capture,
+            // while also warming up the memory pool.
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+                if (node->op == GGML_OP_ROPE) {
+                    ggml_cann_rope_cache_preload(*cann_ctx, node);
+                    break;
+                }
+            }
         }
     }
 #else