]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : remove legacy copy-op pointer indirection code (llama/16485)
authorAnav Prasad <redacted>
Tue, 14 Oct 2025 09:53:49 +0000 (09:53 +0000)
committerGeorgi Gerganov <redacted>
Tue, 14 Oct 2025 19:07:44 +0000 (22:07 +0300)
* remove legacy copy-op pointer indirection code

* further removal of copy-op indirection code

* renamed check_node_graph_compatibility_and_refresh_copy_ops function

src/ggml-cuda/common.cuh
src/ggml-cuda/cpy.cu
src/ggml-cuda/cpy.cuh
src/ggml-cuda/ggml-cuda.cu

index e0abde5427c832852f1803b70f1167986af4f05e..41ff89c4d69227c3030b752610d6fe89ab39e118 100644 (file)
@@ -944,13 +944,6 @@ struct ggml_cuda_graph {
     bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
     std::vector<ggml_graph_node_properties> ggml_graph_properties;
-    bool use_cpy_indirection = false;
-    std::vector<char *> cpy_dest_ptrs;
-    char ** dest_ptrs_d;
-    int dest_ptrs_size = 0;
-    // Index to allow each cpy kernel to be aware of it's position within the graph
-    // relative to other cpy nodes.
-    int graph_cpynode_index = -1;
 #endif
 };
 
index 746f43966b84c244d3824b400007083310f89810..12d5bf7763c38d60de2327469ee2a19834cdb08f 100644 (file)
@@ -8,18 +8,16 @@
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                               const int nb12, const int nb13) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
     const int64_t i03 = i/(ne00 * ne01 * ne02);
@@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                                 const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
@@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+                                 const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
-
     const int i03 = i/(ne00 * ne01 * ne02);
     const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
     const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
@@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
-// Copy destination pointers to GPU to be available when pointer indirection is in use
-
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
-        CUDA_CHECK(cudaStreamSynchronize(stream));
-        if (cuda_graph->dest_ptrs_d != nullptr) {
-            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
-        }
-        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
-        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
-    }
-    // copy destination pointers to GPU
-    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
-    cuda_graph->graph_cpynode_index = 0; // reset index
-#else
-    GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
-#endif
-}
-
 template<typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
@@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
@@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-         ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
     const int num_blocks = ne / QK5_0;
     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
@@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
     const int num_blocks = ne / QK5_1;
     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
@@ -259,25 +233,25 @@ static void ggml_cpy_q5_1_f32_cuda(
     const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12,
     const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {
     const int num_blocks = ne;
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-        ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
@@ -311,16 +285,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    char ** dest_ptrs_d = nullptr;
-    int graph_cpynode_index = -1;
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
-        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
-        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
-    }
-#else
-    GGML_UNUSED(disable_indirection_for_this_node);
-#endif
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -329,134 +293,62 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         } else
 #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
         {
-            if (src0->type == GGML_TYPE_F32) {
-                ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-            } else {
-                CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
-            }
+            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
-        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
-    }
-#else
-    GGML_UNUSED(disable_indirection_for_this_node);
-#endif
-
 }
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    bool disable_indirection = true;
-    ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
-}
-
-void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        // Prioritize CUDA graph compatibility over direct memory copy optimization.
-        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
-        if (src0->type == GGML_TYPE_F32) {
-            return (void*) cpy_flt<cpy_1_flt<float, float>>;
-        } else {
-            return nullptr;
-        }
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<float, float>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<float, half>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
-    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
-    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
-    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
-    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
-    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<half, half>>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<half, float>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
-    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
-    } else {
-        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
-                ggml_type_name(src0->type), ggml_type_name(src1->type));
-    }
+    ggml_cuda_cpy(ctx, src0, dst);
 }
index 0bd3c0c6f8c277e00e54b95377307d0ea43a9647..a7a87d8fcfb7eb2a249d5a8ba3b7d53e0c113936 100644 (file)
@@ -2,10 +2,6 @@
 
 #define CUDA_CPY_BLOCK_SIZE 64
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1,  bool disable_indirection = false);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-
-void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
-
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
index 856e9de2e1115c339e38ae58aee8d2b947ff5b90..83b82c1ad3e5c920ae348b7af476e17b01810b39 100644 (file)
@@ -2633,11 +2633,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
     bool use_cuda_graph) {
 
     // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
 
     const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
     const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
@@ -2688,33 +2687,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_CPY) {
-
-            // Store the pointers which are updated for each token, such that these can be sent
-            // to the device and accessed using indirection from CUDA graph
-            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
-
-            // store a pointer to each copy op CUDA kernel to identify it later
-            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-            if (!ptr) {
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-            }
-        }
-
         if (!use_cuda_graph) {
             break;
         }
     }
 
-    if (use_cuda_graph) {
-        cuda_ctx->cuda_graph->use_cpy_indirection = true;
-        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
-        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
-    }
-
     return use_cuda_graph;
 }
 
@@ -2733,7 +2710,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
 
 static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     if (node->data != graph_node_properties->node_address &&
-          node->op != GGML_OP_CPY &&
           node->op != GGML_OP_VIEW) {
         return false;
     }
@@ -2754,7 +2730,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         if (node->src[i] &&
             node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_CPY &&
             node->op != GGML_OP_VIEW
         ) {
             return false;
@@ -3120,7 +3095,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
 
-        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
+        use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
@@ -3147,10 +3122,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
-    if (!use_cuda_graph) {
-        cuda_ctx->cuda_graph->use_cpy_indirection = false;
-    }
-
 #else
     bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;