]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: Add `fastdiv` to `k_bin_bcast*`, giving 1-3% E2E performance (llama/15872)
authorOliver Simons <redacted>
Wed, 10 Sep 2025 20:04:03 +0000 (22:04 +0200)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* Add fastdiv and fastmodulo to k_bin_bcast kernel

* Address review comments

* `prod_` instead of `prod` suffix

* Add test case for `k_bin_bcast_unravel` in CUDA backend

src/ggml-cuda/binbcast.cu
tests/test-backend-ops.cpp

index 1c76566344a884493ba77fe6a0e7b0697b3bed32..725e1a81a1fc708f4b24b369e7b756635b2be685 100644 (file)
@@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
     return a / b;
 }
 
-
-
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
-static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        const int ne0, const int ne1, const int ne2, const int ne3,
-        const int ne10, const int ne11, const int ne12, const int ne13,
-        /*int s0, */ const int s1, const int s2, const int s3,
-        /*int s00,*/ const int s01, const int s02, const int s03,
-        /*int s10,*/ const int s11, const int s12, const int s13,
-        src1_ptrs... src1s) {
-    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
-    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
-    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
-    const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+template <float (*bin_op)(const float, const float),
+          typename src0_t,
+          typename src1_t,
+          typename dst_t,
+          typename... src1_ptrs>
+static __global__ void k_bin_bcast(const src0_t *         src0,
+                                   const src1_t *         src1,
+                                   dst_t *                dst,
+                                   const int              ne0,
+                                   const int              ne1,
+                                   const int              ne2,
+                                   const uint3            ne3,
+                                   const uint3            ne10,
+                                   const uint3            ne11,
+                                   const uint3            ne12,
+                                   const uint3            ne13,
+                                   /*int s0, */ const int s1,
+                                   const int              s2,
+                                   const int              s3,
+                                   /*int s00,*/ const int s01,
+                                   const int              s02,
+                                   const int              s03,
+                                   /*int s10,*/ const int s11,
+                                   const int              s12,
+                                   const int              s13,
+                                   src1_ptrs... src1s) {
+    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
+    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
         return;
     }
 
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
+    const uint32_t i11 = fastmodulo(i1, ne11);
+    const uint32_t i12 = fastmodulo(i2, ne12);
+    const uint32_t i13 = fastmodulo(i3, ne13);
 
     const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
-    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
-        const int i10 = i0 % ne10;
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+        const uint32_t i10 = fastmodulo(i0, ne10);
 
         float result = src0_row ? (float) src0_row[i0] : 0.0f;
         if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     }
 }
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
-static __global__ void k_bin_bcast_unravel(const src0_t *   src0, const src1_t *   src1, dst_t *          dst,
-        const int ne0, const int ne1, const int ne2,const int ne3,
-        const int ne10, const int ne11, const int ne12, const int ne13,
-        /*int s0, */ const int s1, const int s2, const int s3,
-        /*int s00,*/ const int s01, const int s02, const int s03,
-        /*int s10,*/ const int s11, const int s12, const int s13,
-        src1_ptrs ... src1s) {
+template <float (*bin_op)(const float, const float),
+          typename src0_t,
+          typename src1_t,
+          typename dst_t,
+          typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
+                                           const src1_t *         src1,
+                                           dst_t *                dst,
+                                           const uint3            ne0,
+                                           const uint3            ne1,
+                                           const uint3            ne2,
+                                           const uint32_t         ne3,
+                                           const uint3            prod_012,
+                                           const uint3            prod_01,
+                                           const uint3            ne10,
+                                           const uint3            ne11,
+                                           const uint3            ne12,
+                                           const uint3            ne13,
+                                           /*int s0, */ const int s1,
+                                           const int              s2,
+                                           const int              s3,
+                                           /*int s00,*/ const int s01,
+                                           const int              s02,
+                                           const int              s03,
+                                           /*int s10,*/ const int s11,
+                                           const int              s12,
+                                           const int              s13,
+                                           src1_ptrs... src1s) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
+    const uint32_t i3 = fastdiv(i, prod_012);
+    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
 
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
         return;
     }
 
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
+    const int i11 = fastmodulo(i1, ne11);
+    const int i12 = fastmodulo(i2, ne12);
+    const int i13 = fastmodulo(i3, ne13);
 
     const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t *   src0, const src1_t *
     const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
-    const int i10 = i0 % ne10;
+    const int i10 = fastmodulo(i0, ne10);
 
     float result = src0_row ? (float) src0_row[i0] : 0.0f;
     if constexpr (sizeof...(src1_ptrs) > 0) {
@@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
         //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
 
-        int64_t ne10 = cne1[0];
-        int64_t ne11 = cne1[1];
-        int64_t ne12 = cne1[2];
-        int64_t ne13 = cne1[3];
-
         size_t nb0 = cnb[0];
         size_t nb1 = cnb[1];
         size_t nb2 = cnb[2];
@@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
         block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
         block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
 
-        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
-                        (ne1 + block_dims.y - 1) / block_dims.y,
+        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                         (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
 
+        const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
+        const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
+        const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
+        const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
+
         if (block_nums.z > 65535) {
-            int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+            int         block_num  = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+            const uint3 prod_012    = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+            const uint3 prod_01     = init_fastdiv_values((uint32_t) (ne0 * ne1));
+            const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+            const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+            const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
             if constexpr (sizeof...(I) > 0) {
-                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
-                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                        ne0, ne1, ne2, ne3,
-                        ne10, ne11, ne12, ne13,
-                        /* s0, */ s1, s2, s3,
-                        /* s00,*/ s01, s02, s03,
-                        /* s10,*/ s11, s12,s13,
-                        (const src1_t *) dst->src[I + 1]->data...);
+                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
+                    ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00,*/ s01, s02, s03,
+                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
                 k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
-                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                        ne0, ne1, ne2, ne3,
-                        ne10, ne11, ne12, ne13,
-                        /* s0, */ s1, s2, s3,
-                        /* s00,*/ s01, s02, s03,
-                        /* s10,*/ s11, s12,s13);
+                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
+                                                           ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
+                                                           /* s0, */ s1, s2, s3,
+                                                           /* s00,*/ s01, s02, s03,
+                                                           /* s10,*/ s11, s12, s13);
             }
         } else {
+            const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
             if constexpr (sizeof...(I) > 0) {
-                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
-                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                        ne0, ne1, ne2, ne3,
-                        ne10, ne11, ne12, ne13,
-                        /* s0, */ s1, s2, s3,
-                        /* s00,*/ s01, s02, s03,
-                        /* s10,*/ s11, s12,s13,
-                        (const src1_t *) dst->src[I + 1]->data...);
+                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00,*/ s01, s02, s03,
+                    /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
             } else {
-                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
-                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                        ne0, ne1, ne2, ne3,
-                        ne10, ne11, ne12, ne13,
-                        /* s0, */ s1, s2, s3,
-                        /* s00,*/ s01, s02, s03,
-                        /* s10,*/ s11, s12,s13);
+                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00,*/ s01, s02, s03,
+                    /* s10,*/ s11, s12, s13);
             }
         }
     }
index 4a882ab072cc01051dc009e06c3a99685d9ef3a0..b54a1a4e823f97056b242deb9942f1022726bb1d 100644 (file)
@@ -6050,6 +6050,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
         add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
 
+        // test case for k_bin_bcast_unravel in CUDA backend
+        add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
+
         // stable diffusion
         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});