]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Allow multiple copy function pointers for CUDA graph kernel param updates (llama...
authoragray3 <redacted>
Mon, 27 May 2024 17:33:42 +0000 (18:33 +0100)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
CUDA graphs require parameter updates to kernels associated with
GGML_OP_CPY nodes. Previously the implementation only checked for a
single CUDA kernel in such nodes, but this caused a bug in cases where
2 such kernels exist. This fixes the issue by using a vector to allow
multiple function pointers to be stored and checked against.

Fixes #7942

src/ggml-cuda.cu

index b82167cbf7227ba10374f6318498f0dc3c7f869d..2a90ee55c69a059576cb2eaafe593e5dcfd94c93 100644 (file)
@@ -2510,9 +2510,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // pointer to CUDA cpy kernel, which is required to identify
+    // vector of pointers to CUDA cpy kernels, which are required to identify
     // kernel parameters which need updated in the graph for each token
-    void * ggml_cuda_cpy_fn_ptr = nullptr;
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2588,9 +2588,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             if (node->op == GGML_OP_CPY) {
                 // store the copy op parameter which changes with each token.
                 cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                if (ggml_cuda_cpy_fn_ptr == nullptr) {
-                    // store a pointer to the copy op CUDA kernel to identify it later
-                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                // store a pointer to each copy op CUDA kernel to identify it later
+                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
                 }
             }
 
@@ -2720,7 +2721,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
             int k = 0;
             for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                     char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                     cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                     CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));