]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda: fixes for concurrent streams (#18496)
authorAman Gupta <redacted>
Sat, 3 Jan 2026 15:15:01 +0000 (23:15 +0800)
committerGitHub <redacted>
Sat, 3 Jan 2026 15:15:01 +0000 (23:15 +0800)
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu

index 62e618850bf3e171d0cef0cc37a492e4a4b44163..302065ce9f935bea3975cd0ef8040f34c2144d3a 100644 (file)
@@ -1063,6 +1063,7 @@ struct ggml_cuda_graph {
     bool disable_due_to_too_many_updates = false;
     bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
+    bool cuda_graphs_enabled = false;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
 #endif
 };
index 84eccea3f7b6c9db041500966a03b2ed1c8ed2b4..f05d5562ba0483b56375bdc748b96a9dec77ba8f 100644 (file)
@@ -3253,6 +3253,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
                 }
             }
+
             if (should_launch_concurrent_events) {
                 // Restore original node order within each concurrent region to enable fusion within streams
 
@@ -3304,6 +3305,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                         cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
                     }
                 }
+            } else {
+                stream_ctx.concurrent_events.clear();
             }
 
             for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3692,11 +3695,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     }
 }
 
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    ggml_cuda_set_device(cuda_ctx->device);
-
+static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
@@ -3706,7 +3705,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     }
 
     bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
@@ -3727,6 +3725,29 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         use_cuda_graph = false;
     }
 
+    cuda_ctx->cuda_graph->cuda_graphs_enabled = use_cuda_graph;
+#else
+    bool use_cuda_graph = false;
+#endif // USE_CUDA_GRAPH
+
+    return use_cuda_graph;
+}
+
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+
+    bool use_cuda_graph             = false;
+    bool cuda_graph_update_required = false;
+
+    // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
+    // we call it here instead.
+#ifdef USE_CUDA_GRAPH
+    if (!cuda_ctx->cuda_graph) {
+        use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
+    } else {
+        use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph->cuda_graphs_enabled;
+    }
+
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
 
@@ -3746,6 +3767,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 #endif
         }
     }
+#endif // USE_CUDA_GRAPH
 
     if (use_cuda_graph && cuda_graph_update_required) {
         // Start CUDA graph capture
@@ -3757,11 +3779,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
-#else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
-
     bool graph_evaluated_or_captured = false;
 
     evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
@@ -3797,8 +3814,10 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
 static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 
+    const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
+
     static bool enable_graph_optimization = [] {
-        const char * env = getenv("GGML_CUDA_GRAPH_OPT");
+        const char * env     = getenv("GGML_CUDA_GRAPH_OPT");
         return env != nullptr && atoi(env) == 1;
     }();
 
@@ -3806,12 +3825,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
         return;
     }
 
-    GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
-    GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
-
     ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
     stream_context.reset();
 
+    if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
+        return;
+    }
+
     // number of out-degrees for a particular node
     std::unordered_map<const ggml_tensor *, int> fan_out;
     // reverse mapping of node to index in the cgraph
@@ -3872,6 +3892,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
         if (count >= min_fan_out && count <= max_fan_out) {
             const int root_node_idx = node_indices[root_node];
 
+            // only optimize for attn_norm
+            // TODO: make this more generic
+            if (!strstr(root_node->name, "attn_norm")) {
+                continue;
+            }
+
             bool is_part_of_event = false;
             for (const auto & [start, end] : concurrent_node_ranges) {
                 if (root_node_idx >= start && root_node_idx <= end) {