]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: faster non-contiguous concat (llama/10760)
authora3sh <redacted>
Thu, 12 Dec 2024 18:09:50 +0000 (02:09 +0800)
committerGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 10:52:16 +0000 (12:52 +0200)
* faster uncontiguous concat

* Use a lambda to avoid code duplication

Co-authored-by: Diego Devesa <redacted>
* Update ggml/src/ggml-cuda/concat.cu

* add constexpr  and static assert

---------

Co-authored-by: Diego Devesa <redacted>
ggml/src/ggml-cuda/concat.cu

index dac10ec36b0bd02dfbf922bb45a79470d4fb7465..2f42b8a9538e2c5a65de96121672dfbcfcf7db78 100644 (file)
@@ -94,7 +94,9 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
 }
 
 // non-contiguous kernel (slow)
-static __global__ void concat_f32_non_cont(
+template <int dim>
+static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
+    concat_f32_non_cont(
         const char * src0,
         const char * src1,
               char * dst,
@@ -121,22 +123,28 @@ static __global__ void concat_f32_non_cont(
           uint64_t   nb0,
           uint64_t   nb1,
           uint64_t   nb2,
-          uint64_t   nb3,
-          int32_t   dim) {
+          uint64_t   nb3){
+    static_assert(dim >= 0 && dim <= 3);
+
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
     const int64_t i1 = blockIdx.x;
 
-    int64_t o[4] = {0, 0, 0, 0};
-    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
-
     const float * x;
 
-    for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
         if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
             x = (const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
         } else {
-            x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+            if constexpr (dim == 0) {
+                x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
+            } else if constexpr (dim == 1) {
+                x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
+            } else if constexpr (dim == 2) {
+                x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
+            } else if constexpr (dim == 3) {
+                x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
+            }
         }
 
         float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -182,15 +190,32 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         }
     } else {
         dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
-        concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
-                (const char *)src0->data,
-                (const char *)src1->data,
-                (      char *)dst->data,
+        auto launch_kernel = [&](auto dim) {
+            concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
                 src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                 src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                 src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
                 src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
-                dst->ne[0],  dst->ne[1],  dst->ne[2],  dst->ne[3],
-                dst->nb[0],  dst->nb[1],  dst->nb[2],  dst->nb[3], dim);
+                dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+        };
+        switch (dim) {
+            case 0:
+                launch_kernel(std::integral_constant<int, 0>{});
+                break;
+            case 1:
+                launch_kernel(std::integral_constant<int, 1>{});
+                break;
+            case 2:
+                launch_kernel(std::integral_constant<int, 2>{});
+                break;
+            case 3:
+                launch_kernel(std::integral_constant<int, 3>{});
+                break;
+            default:
+                GGML_ABORT("Invalid dim: %d", dim);
+                break;
+        }
     }
 }