]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Simplify and improve CUDA graphs through use of indirect copy pointers (llama/9017)
authorAlan Gray <redacted>
Thu, 3 Apr 2025 01:31:15 +0000 (02:31 +0100)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
* CUDA: Simplify and improve CUDA graphs through use of indirect copy pointers

Previously there was complexity in the CUDA graphs implementation due
frequently changing parameters to copy kernels associated with K and V
cache pointers. This patch simplifies by using indirection to avoid
such parameters frequently changing, avoiding the need for frequent
graph updates.

Fixes #12152

* Addressed comments

* fix HIP builds

* properly sync to stream

* removed ggml_cuda_cpy_fn_ptrs

* move stream sync before free

* guard to only use indirection with graphs

* style fixes

* check for errors

---------

Co-authored-by: slaren <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/cpy.cu
ggml/src/ggml-cuda/cpy.cuh
ggml/src/ggml-cuda/ggml-cuda.cu

index a718b6a1288032564381bd11533203171919bf81..8284a0017d272a791cc87778f0e19b4c006c44cf 100644 (file)
@@ -729,7 +729,13 @@ struct ggml_cuda_graph {
     bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
-    std::vector<char **> updated_kernel_arg;
+    bool use_cpy_indirection = false;
+    std::vector<char *> cpy_dest_ptrs;
+    char ** dest_ptrs_d;
+    int dest_ptrs_size = 0;
+    // Index to allow each cpy kernel to be aware of it's position within the graph
+    // relative to other cpy nodes.
+    int graph_cpynode_index = -1;
 #endif
 };
 
index cca2bee0b279195a1a7aec857c3fca71297f65bf..1f3151e7566bdffc632dd9b2e442cc133a8b312e 100644 (file)
@@ -32,16 +32,18 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
                                    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                   const int nb12, const int nb13) {
+                                   const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
     const int64_t i03 = i/(ne00 * ne01 * ne02);
@@ -288,16 +290,18 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
+                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
@@ -314,16 +318,18 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
+                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
+    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
+
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
@@ -339,66 +345,84 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
+// Copy destination pointers to GPU to be available when pointer indirection is in use
+
+void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+        if (cuda_graph->dest_ptrs_d != nullptr) {
+            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
+        }
+        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
+        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
+    }
+    // copy destination pointers to GPU
+    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
+    cuda_graph->graph_cpynode_index = 0; // reset index
+#endif
+}
+
 static void ggml_cpy_f16_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
@@ -407,22 +431,22 @@ static void ggml_cpy_q4_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
@@ -431,22 +455,22 @@ static void ggml_cpy_q4_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
     const int num_blocks = ne / QK5_0;
     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
@@ -455,22 +479,22 @@ static void ggml_cpy_q5_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
     const int num_blocks = ne / QK5_1;
     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
@@ -479,32 +503,32 @@ static void ggml_cpy_q5_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream) {
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
@@ -541,46 +565,60 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
+    char ** dest_ptrs_d = nullptr;
+    int graph_cpynode_index = -1;
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection) {
+        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
+        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
+    }
+#endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection) {
+        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+    }
+#endif
+
 }
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 28b06cddaa87b98316ae3a94fe3ff793812d398b..6bed0564df27af508b16dea1f9b9816449e9c8db 100644 (file)
@@ -7,3 +7,5 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
index 060550c57208d915bdb04ad09834c0056bca7d80..78717df1a6095165aef17e5f49ac98f4916728ae 100644 (file)
@@ -2469,10 +2469,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 
 #ifdef USE_CUDA_GRAPH
 static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+    bool use_cuda_graph) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2504,8 +2505,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
 
         if (node->op == GGML_OP_CPY) {
-            // store the copy op parameter which changes with each token.
-            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+
+            // Store the pointers which are updated for each token, such that these can be sent
+            // to the device and accessed using indirection from CUDA graph
+            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
+
             // store a pointer to each copy op CUDA kernel to identify it later
             void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
             if (!ptr) {
@@ -2513,10 +2517,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #ifndef NDEBUG
                 GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
 #endif
-            } else {
-                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                }
             }
         }
 
@@ -2525,6 +2525,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
         }
     }
 
+    if (use_cuda_graph) {
+        cuda_ctx->cuda_graph->use_cpy_indirection = true;
+        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
+        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
+    }
+
     return use_cuda_graph;
 }
 
@@ -2579,51 +2585,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     return true;
 }
 
-static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
-
-    if (cuda_graph_update_required) {
-        // Extract nodes from graph
-        // First call with null argument gets number of nodes in graph
-        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-        // Subsequent call with non-null argument gets nodes
-        cuda_ctx->cuda_graph->nodes.clear();
-        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-        cuda_ctx->cuda_graph->params.clear();
-        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-            // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                cudaGraphNodeType node_type;
-                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                if (node_type == cudaGraphNodeTypeKernel) {
-                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                    if (stat == cudaErrorInvalidDeviceFunction) {
-                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                        // We don't need to update blas nodes, so clear error and move on.
-                        (void)cudaGetLastError();
-                    } else {
-                        GGML_ASSERT(stat == cudaSuccess);
-                    }
-                }
-            }
-        }
-    } else {
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        // on update steps, the live parameters will already be captured
-        int k = 0;
-        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                *(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
-                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-            }
-        }
-    }
-}
-
 static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
     bool cuda_graph_update_required = false;
@@ -2683,8 +2644,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 #endif
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
-   [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
-    bool & cuda_graph_update_required) {
+    bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2734,13 +2694,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
             CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
         }
-
-        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
-
-        // Update graph executable
-        update_cuda_graph_executable(cuda_ctx);
-
+        if (cuda_graph_update_required) { // Update graph executable
+            update_cuda_graph_executable(cuda_ctx);
+        }
         // Launch graph
         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
 #else
@@ -2754,10 +2710,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     ggml_cuda_set_device(cuda_ctx->device);
 
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
-
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
@@ -2791,8 +2743,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
 
-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
-                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
@@ -2813,6 +2764,10 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
+    if (!use_cuda_graph) {
+        cuda_ctx->cuda_graph->use_cpy_indirection = false;
+    }
+
 #else
     bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;
@@ -2820,7 +2775,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool graph_evaluated_or_captured = false;
 
-    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }