// Decide whether this graph should be run through the ACL graph path.
// use_cann_graph starts out enabled; the checks below may turn it off.
bool use_cann_graph = true;
bool cann_graph_update_required = false;
// Function-local static: the environment lookup runs once, on first execution of this statement.
// NOTE(review): an unset variable is passed to parse_bool as "" -- presumably parsed as false; confirm.
+ static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
+ if (!prefill_use_graph) {
+ // Do not use acl_graph for prefill. Only the first FLASH_ATTN_EXT node in the graph is inspected.
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+ // TODO: Optimize here. Currently, the sequence length can only
+ // be obtained from the input of the flash-attention (FA) node.
+ if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+ // Q -> src[0], shape: [B, S, N, D]; ne[1] == 1 keeps the graph path enabled (NOTE(review): assumes ne[1] is the sequence length S -- confirm).
+ use_cann_graph = (node->src[0]->ne[1] == 1);
+ break;
+ }
+ } // if no FLASH_ATTN_EXT node was found, use_cann_graph keeps its current value
+ }
+
// A context with ACL graph mode turned off overrides whatever was decided above.
if (!cann_ctx->acl_graph_mode) {
use_cann_graph = false;
}