]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : CUDA Graph Compute Function Refactor (precursor for performance improvements...
authorAndreas Kieslinger <redacted>
Mon, 13 Jan 2025 15:45:53 +0000 (16:45 +0100)
committerGeorgi Gerganov <redacted>
Tue, 14 Jan 2025 08:38:01 +0000 (10:38 +0200)
* Refactor: Moves cuda graph executable update step to separate function.

* Refactor: Moves cuda graph update check to separate function.

* Refactor: Moves cuda graph maintenance (update or adjusting copy parameters) to separate function for improved readability.

* Fix: Adds missing reference to maintain_cuda_graph() definition.

* Refactor: Improves structure and abstractions by moving CUDA graph evaluation and capture to its own function.

* Refactor: Moves node graph checks and copy ops into individual function for improved readability.

* Refactor: Removes code permanently excluded from compilation to increase readability.

* Style: Adds missing newline

* Style: Consolidates several neighboring '#ifdef USE_CUDA_GRAPH' into a single one

* Refactor: Makes 'cuda_graph_update_required' a local variable

* remove double lines between functions

---------

Co-authored-by: slaren <redacted>
ggml/src/ggml-cuda/ggml-cuda.cu

index 8476ee1bca50c8c0cd8645f0238e7a4ef8f691d5..1dac397c4b0836b2111f0c2f910c7edfff695d53 100644 (file)
@@ -2289,6 +2289,66 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
+static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // store the copy op parameter which changes with each token.
+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            } else {
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    return use_cuda_graph;
+}
+
 static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     graph_node_properties->node_address = node->data;
     graph_node_properties->node_op = node->op;
@@ -2339,149 +2399,105 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
 
     return true;
 }
-#endif
 
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    ggml_cuda_set_device(cuda_ctx->device);
+static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
 
-#ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+    if (cuda_graph_update_required) {
+        // Extract nodes from graph
+        // First call with null argument gets number of nodes in graph
+        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+        // Subsequent call with non-null argument gets nodes
+        cuda_ctx->cuda_graph->nodes.clear();
+        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+        cuda_ctx->cuda_graph->params.clear();
+        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+        if (cuda_ctx->cuda_graph->num_nodes > 0) {
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
 
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+            // Loop over nodes, and extract kernel parameters from each node
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                cudaGraphNodeType node_type;
+                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                if (node_type == cudaGraphNodeTypeKernel) {
+                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                    if (stat == cudaErrorInvalidDeviceFunction) {
+                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                        // We don't need to update blas nodes, so clear error and move on.
+                        cudaGetLastError();
+                    } else {
+                        GGML_ASSERT(stat == cudaSuccess);
+                    }
+                }
+            }
+        }
+    } else {
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        // on update steps, the live parameters will already be captured
+        int k = 0;
+        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+            }
+        }
     }
+}
+
+static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
-    bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
-#endif
-        }
+    if (cuda_ctx->cuda_graph->instance == nullptr) {
+        cuda_graph_update_required = true;
     }
 
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
+    // Check if the graph size has changed
+    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        cuda_graph_update_required = true;
+        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
     }
 
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) {
-            cuda_graph_update_required = true;
+    // Loop over nodes in GGML graph to determine if CUDA graph update is required
+    // and store properties to allow this comparison for the next token
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        bool has_matching_properties = true;
+        if (!cuda_graph_update_required) {
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
         }
-
-        // Check if the graph size has changed
-        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        if (!has_matching_properties) {
             cuda_graph_update_required = true;
-            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
-        }
-
-        // Loop over nodes in GGML graph to determine if CUDA graph update is required
-        // and store properties to allow this comparison for the next token
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            bool has_matching_properties = true;
-            if (!cuda_graph_update_required) {
-                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-            }
-            if (!has_matching_properties) {
-                cuda_graph_update_required = true;
-            }
-            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
         }
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+    }
 
-        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            ggml_tensor * node = cgraph->nodes[i];
-
-            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                continue;
-            }
-
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
-#ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                    }
-                }
-            }
-
-            if (!use_cuda_graph) {
-                break;
-            }
-        }
+    return cuda_graph_update_required;
+}
 
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
+static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+    cudaGraphExecUpdateResultInfo result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+    if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+        GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
 #endif
-        }
-    }
-
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
-        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+        // The pre-existing graph exec cannot be updated due to violated constraints
+        // so instead clear error and re-instantiate
+        cudaGetLastError();
+        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+        cuda_ctx->cuda_graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+    } else {
+        GGML_ASSERT(stat == cudaSuccess);
     }
+}
+#endif
 
-#else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
-
-    bool graph_evaluated_or_captured = false;
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+   [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
+    bool & cuda_graph_update_required) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2519,19 +2535,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
                 CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
                 cuda_ctx->cuda_graph->graph = nullptr;
             }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
 
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
@@ -2544,72 +2549,91 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
 
         // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
 
-        if (cuda_graph_update_required) {
-            // Extract nodes from graph
-            // First call with null argument gets number of nodes in graph
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            // Subsequent call with non-null argument gets nodes
-            cuda_ctx->cuda_graph->nodes.clear();
-            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-            cuda_ctx->cuda_graph->params.clear();
-            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-            if (cuda_ctx->cuda_graph->num_nodes > 0) {
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-                // Loop over nodes, and extract kernel parameters from each node
-                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                    cudaGraphNodeType node_type;
-                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                    if (node_type == cudaGraphNodeTypeKernel) {
-                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                        if (stat == cudaErrorInvalidDeviceFunction) {
-                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                            // We don't need to update blas nodes, so clear error and move on.
-                            cudaGetLastError();
-                        } else {
-                            GGML_ASSERT(stat == cudaSuccess);
-                        }
-                    }
-                }
-            }
+        // Update graph executable
+        update_cuda_graph_executable(cuda_ctx);
+
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif  // USE_CUDA_GRAPH
+    }
+}
+
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    // vector of pointers to CUDA cpy kernels, which are required to identify
+    // kernel parameters which need updated in the graph for each token
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
         }
+    }
 
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-            int k = 0;
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-                }
-            }
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-gpu for now. TO DO investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
+                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (use_cuda_graph && cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
         }
 
-        // Update graph executable
-        cudaGraphExecUpdateResultInfo result_info;
-        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-        if (stat == cudaErrorGraphExecUpdateFailure) {
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
-            // The pre-existing graph exec cannot be updated due to violated constraints
-            // so instead clear error and re-instantiate
-            cudaGetLastError();
-            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-            cuda_ctx->cuda_graph->instance = nullptr;
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        } else {
-            GGML_ASSERT(stat == cudaSuccess);
         }
-        // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
 #else
-        graph_evaluated_or_captured = true;
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
 #endif // USE_CUDA_GRAPH
-    }
+
+    bool graph_evaluated_or_captured = false;
+
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }