]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : enable CUDA Graph on CUDA Toolkit < 12.x (llama/12394)
authorGaurav Garg <redacted>
Mon, 17 Mar 2025 18:25:13 +0000 (23:55 +0530)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
* Enable CUDA Graph on CTK < 12.x

`cudaGraphExecUpdate` API was changed on 12.x. For this reason CUDA graph support was disabled on older CUDA toolkit. This change enables CUDA support in CTK version < 12.x by using older API if CTK < 12.x.

* Fix compilation errors with MUSA

* Disable CUDA Graph for MUSA

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/vendors/hip.h
ggml/src/ggml-cuda/vendors/musa.h
ggml/src/ggml-musa/CMakeLists.txt

index 4d4ac47c034e1d26b86a10ced3f24b7a169e99f5..e78205e5d53afd4a47ef77a49d8ad242217ef9bc 100644 (file)
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
 };
 
 
-#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
 #define USE_CUDA_GRAPH
 #endif
 
index 497de37be8210dfcc735bccf7d6b4b34e5c6b2f0..9bba398ce6be113c1c5a4b438e8393253cddefa9 100644 (file)
@@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
 
 static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 
+#if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-#ifdef __HIP_PLATFORM_AMD__
-    hipGraphNode_t errorNode;
-    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
-#else
     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-#endif
+#else
+    cudaGraphNode_t errorNode;
+    cudaGraphExecUpdateResult result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000
+
     if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
index 81964611c6064c9f5d2c1e72409635cb1f54c7a8..aace21e3a8b18ff114ff4f5dceb59919580f677b 100644 (file)
 #define cudaGraphExecDestroy hipGraphExecDestroy
 #define cudaGraphLaunch hipGraphLaunch
 #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
 #define cudaGraphNodeType hipGraphNodeType
 #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
 #define cudaGraphInstantiate hipGraphInstantiate
index 6cc1b69ee3390c619c574c3c6e1a871051ba8e94..997f671431e01e16a61a1daf48f38b328350d6f8 100644 (file)
 #define cudaGraphExecDestroy musaGraphExecDestroy
 #define cudaGraphExec_t musaGraphExec_t
 #define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
 #define cudaGraphGetNodes musaGraphGetNodes
 #define cudaGraphInstantiate musaGraphInstantiate
 #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
 #define cudaGraph_t musaGraph_t
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
 
 typedef mt_bfloat16 nv_bfloat16;
index 166970ca6bfb8eb3bb235cdf5c71d28db5a65a08..92f05d5558c80bff0c1cf9f5bf272fee0e2ef80e 100644 (file)
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
-    if (GGML_CUDA_GRAPHS)
-        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()