]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Allow number of nodes in CUDA graph to change (llama/7738)
authoragray3 <redacted>
Tue, 4 Jun 2024 20:06:49 +0000 (21:06 +0100)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
Previously the code would have failed to cope in the case that the
number of nodes changes in an existing CUDA graph. This fixes the
issue by removing an unnecessary conditional.

src/ggml-cuda.cu

index daaa0cd6a5473a23426a729d2549aba4cc2fc49c..c81c6a0d783be118cdea5bdb64d6eb17fa939a33 100644 (file)
@@ -2702,10 +2702,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
         if (cuda_graph_update_required) {
             // Extract nodes from graph
-            if (cuda_ctx->cuda_graph->num_nodes == 0) {
-                // First call with null argument gets number of nodes in graph
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            }
+            // First call with null argument gets number of nodes in graph
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
             // Subsequent call with non-null argument gets nodes
             cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
             cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);