for (int i = 0; i < GGML_MAX_SRC; i++) {
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
}
+ memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
return false;
}
}
+
+ if (node->op == GGML_OP_SCALE &&
+ memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+ return false;
+ }
+
return true;
}
// First call with null argument gets number of nodes in graph
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
+ cuda_ctx->cuda_graph->nodes.clear();
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+ cuda_ctx->cuda_graph->params.clear();
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
if (cuda_ctx->cuda_graph->num_nodes > 0) {
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
struct ggml_cuda_graph {