bool use_cuda_graph = true;
bool cuda_graph_update_required = false;
- // pointer to CUDA cpy kernel, which is required to identify
+ // vector of pointers to CUDA cpy kernels, which are required to identify
// kernel parameters which need to be updated in the graph for each token
- void * ggml_cuda_cpy_fn_ptr = nullptr;
+ std::vector<void *> ggml_cuda_cpy_fn_ptrs;
if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
if (node->op == GGML_OP_CPY) {
// store the copy op parameter which changes with each token.
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
- if (ggml_cuda_cpy_fn_ptr == nullptr) {
- // store a pointer to the copy op CUDA kernel to identify it later
- ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+ // store a pointer to each copy op CUDA kernel to identify it later
+ void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+ if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+ ggml_cuda_cpy_fn_ptrs.push_back(ptr);
}
}
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
int k = 0;
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
- if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+ if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));