]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
improve CUDA cpy memory bandwidth when copying transposed tensor (llama/16841)
authorbssrdf <redacted>
Wed, 5 Nov 2025 20:55:04 +0000 (15:55 -0500)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 16:30:22 +0000 (18:30 +0200)
* WIP

* added a cpy kernel specific to transposed tensor which uses smem to avoid uncoalesced access; test cases also added shwoing improved memory bandwidth

* added BF16 support

* more strict check to make sure src0 is a transpose

* reformulated to handle more complicated transpose cases

* bring back 2D transpose for higher performance

* allow build on windows

* tranpose copy more shapes

* minor tweak

* final clean up

* restore some test cases

* keep only the kernel for true tranposed case; updated with review suggestions

* make CI happy

* remove headers not needed

* reduced bank conflicts for fp16 and bf16

* add missing const*

* now bank conflicts free

* use padding instead of swizzling

---------

Co-authored-by: bssrdf <redacted>
src/ggml-cuda/cpy.cu
tests/test-backend-ops.cpp

index c5821acbdeb8a426906475589e3a0c3581a295c1..1dba60eb143ef13ef4b9db8a9c8a109781cec1e7 100644 (file)
@@ -7,6 +7,10 @@
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
+const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
+const int CUDA_CPY_BLOCK_NM = 8;     // block size of 3rd dimension if available
+const int CUDA_CPY_BLOCK_ROWS = 8;   // block dimension for marching through rows
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -35,6 +39,55 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+template <typename T>
+static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                               const int nb12, const int nb13) {
+
+    const T* src = reinterpret_cast<const T*>(cx);
+    T* dst = reinterpret_cast<T*>(cdst);
+
+    const int64_t nmat = ne / (ne00 * ne01);
+    const int64_t n = ne00 * ne01;
+
+    const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
+    const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x;  // transpose block offset
+    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+
+    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+
+#pragma unroll
+    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
+
+        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
+        if (imat >= nmat)
+            break;
+
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+            if(x < ne01 && y + j < ne00){
+                const int row = threadIdx.y+j;
+                const int col = threadIdx.x * sizeof(float)/sizeof(T);
+                T *tile2 = reinterpret_cast<T*>(tile[row]);
+                tile2[col] = src[imat*n + (y+j)*ne01 + x];
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+            if (ty + j < ne01 && tx < ne00) {
+                const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
+                const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
+                dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
+            }
+        }
+    }
+}
+
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);
 
@@ -136,15 +189,38 @@ cudaStream_t stream) {
         (cx, cdst, ne);
 }
 
-template<typename src_t, typename dst_t>
+template<typename src_t, typename dst_t, bool transposed = false>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
-    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    if (transposed) {
+        GGML_ASSERT(ne == ne00*ne01*ne02);  // ne[3] is 1 assumed
+        int ne00n, ne01n, ne02n;
+        if (nb00 < nb02) {
+            ne00n = ne00;
+            ne01n = ne01;
+            ne02n = ne02;
+        } else if (nb00 > nb02) {
+            ne00n = ne00;
+            ne01n = ne01*ne02;
+            ne02n = 1;
+        } else {
+            GGML_ASSERT(false);
+        }
+
+        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
+                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    } else {
+        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+    }
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
@@ -310,6 +386,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src1_ddc = (char *) src1->data;
 
     const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+    const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
 
     if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
@@ -322,7 +399,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float>           (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
             ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
@@ -361,7 +442,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half>               (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, half>       (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
             ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16>  (src0_ddc, src1_ddc, ne, main_stream);
@@ -375,7 +460,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             ggml_cpy_flt_cuda<half, float>          (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (can_be_transposed) {
+            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
         if (contiguous_srcs) {
             ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half>  (src0_ddc, src1_ddc, ne, main_stream);
index 967a53c63d86d232e18f29dbf9bc51ee937d836a..f575420279a37bc60030773437f8e31feebf99c8 100644 (file)
@@ -2576,9 +2576,10 @@ struct test_cpy : public test_case {
     const std::array<int64_t, 4> permute_dst;
     bool _src_use_permute;
     bool _dst_use_permute;
+    bool _src_transpose;
 
     std::string vars() override {
-        return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
+        return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose);
     }
 
     double max_nmse_err() override {
@@ -2616,10 +2617,12 @@ struct test_cpy : public test_case {
     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
             std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
-            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
+            std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
+            bool transpose_src = false)
         : type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
           _src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
-          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
+          _dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
+          _src_transpose(transpose_src){}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -2631,6 +2634,11 @@ struct test_cpy : public test_case {
             ggml_set_name(src, "src_permuted");
         }
 
+        if (_src_transpose) {
+            src = ggml_transpose(ctx, src);
+            ggml_set_name(src, "src_transposed");
+        }
+
         ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
         ggml_set_name(dst, "dst");
 
@@ -6641,6 +6649,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
 
     test_cases.emplace_back(new test_cont());
     test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
@@ -7385,6 +7400,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32,  GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32,  {8192, 512, 2, 1}));
 
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
+
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
+
+
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));