};
-#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#define USE_CUDA_GRAPH
#endif
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
-#ifdef __HIP_PLATFORM_AMD__
- hipGraphNode_t errorNode;
- hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
-#else
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-#endif
+#else
+ cudaGraphNode_t errorNode;
+ cudaGraphExecUpdateResult result_info;
+ cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000
+
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#define cudaGraphExecDestroy hipGraphExecDestroy
#define cudaGraphLaunch hipGraphLaunch
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
#define cudaGraphNodeType hipGraphNodeType
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
#define cudaGraphInstantiate hipGraphInstantiate
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;
add_compile_definitions(GGML_USE_MUSA)
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
- if (GGML_CUDA_GRAPHS)
- add_compile_definitions(GGML_CUDA_USE_GRAPHS)
- endif()
-
if (GGML_CUDA_FORCE_MMQ)
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
endif()