]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: use fastdiv in set-rows (llama/16834)
authorAman Gupta <redacted>
Wed, 29 Oct 2025 13:11:53 +0000 (21:11 +0800)
committerGeorgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
* CUDA: use fastdiv in set-rows

* add assert about value fitting in u32

src/ggml-cuda/common.cuh
src/ggml-cuda/set-rows.cu

index 1af23588301dd4c6d5ccad352ebb7fa96d3be47e..6a472be7fbb664e72a9300ddf966de4488eb98c8 100644 (file)
@@ -625,8 +625,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 // and a shift:
 //
 // n/d = (mulhi(n, mp) + n) >> L;
-static const uint3 init_fastdiv_values(uint32_t d) {
-    GGML_ASSERT(d != 0);
+static const uint3 init_fastdiv_values(uint64_t d_64) {
+    GGML_ASSERT(d_64 != 0);
+    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
+
+    uint32_t d = (uint32_t)d_64;
 
     // compute L = ceil(log2(d));
     uint32_t L = 0;
index 1525a159527af36581c78933978972f6a3ccf0d2..631de7e8fa51a2ac6a31ec8bddb0318d28636566 100644 (file)
@@ -4,30 +4,53 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 // Generic quantized set_rows kernel template
-template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
-static __global__ void k_set_rows_quant(
-        const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
-        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-        const int64_t s01, const int64_t s02, const int64_t s03,
-        const int64_t s10, const int64_t s11, const int64_t s12,
-        const int64_t s1, const int64_t s2, const int64_t s3) {
-
+template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
+static __global__ void k_set_rows_quant(const float * __restrict__ src0,
+                                        const idx_t * __restrict__ src1,
+                                        block_type * __restrict__ dst,
+                                        const int64_t ne_total,
+                                        const int64_t ne10,
+                                        const int64_t ne11,
+                                        const int64_t ne12,
+                                        const int64_t ne13,
+                                        const int64_t s01,
+                                        const int64_t s02,
+                                        const int64_t s03,
+                                        const int64_t s10,
+                                        const int64_t s11,
+                                        const int64_t s12,
+                                        const int64_t s1,
+                                        const int64_t s2,
+                                        const int64_t s3,
+                                        const uint3   ne00,
+                                        const uint3   ne01,
+                                        const uint3   ne02,
+                                        const uint3   ne11_fd,
+                                        const uint3   ne12_fd) {
     const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
-    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
 
     if (i >= ne_total) {
         return;
     }
 
     const int64_t i_base = i * qk;
-    const int64_t i03 = i_base / (ne00 * ne01 * ne02);
-    const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
-    const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
-    const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+    uint32_t      tmp    = (uint32_t) i_base;
+    uint2         div_mod;
+
+    div_mod           = fast_div_modulo(tmp, ne00);
+    const int64_t i00 = div_mod.y;
+    tmp               = div_mod.x;
 
-    const int64_t i12 = i03 % ne12;
-    const int64_t i11 = i02 % ne11;
+    div_mod           = fast_div_modulo(tmp, ne01);
+    const int64_t i01 = div_mod.y;
+    tmp               = div_mod.x;
+
+    div_mod           = fast_div_modulo(tmp, ne02);
+    const int64_t i02 = div_mod.y;
+    const int64_t i03 = div_mod.x;
+
+    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
     const int64_t i10 = i01;
 
     const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
     quantize_func(src_block, dst_block);
 
     GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12);
     GGML_UNUSED(ne13);
 }
 
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
     const int64_t s2  = nb2;
     const int64_t s3  = nb3;
 
-    if (ne_total > 0) {
+    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
         k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
-            src0_d, src1_d, dst_d,
-            ne00, ne01, ne02, ne03,
-            ne10, ne11, ne12, ne13,
-            s01, s02, s03,
-            s10, s11, s12,
-            s1, s2, s3);
+            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
+            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
     }
 }
 
-template<typename src_t, typename idx_t, typename dst_t>
-static __global__ void k_set_rows(
-        const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
-        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
-        const int64_t s01, const int64_t s02, const int64_t s03,
-        const int64_t s10, const int64_t s11, const int64_t s12,
-        const int64_t s1, const int64_t s2, const int64_t s3) {
-
+template <typename src_t, typename idx_t, typename dst_t>
+static __global__ void k_set_rows(const src_t * __restrict__ src0,
+                                  const idx_t * __restrict__ src1,
+                                  dst_t * __restrict__ dst,
+                                  const int64_t ne_total,
+                                  const int64_t ne10,
+                                  const int64_t ne11,
+                                  const int64_t ne12,
+                                  const int64_t ne13,
+                                  const int64_t s01,
+                                  const int64_t s02,
+                                  const int64_t s03,
+                                  const int64_t s10,
+                                  const int64_t s11,
+                                  const int64_t s12,
+                                  const int64_t s1,
+                                  const int64_t s2,
+                                  const int64_t s3,
+                                  const uint3   ne00,
+                                  const uint3   ne01,
+                                  const uint3   ne02,
+                                  const uint3   ne11_fd,
+                                  const uint3   ne12_fd) {
     const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
-    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
 
     if (i >= ne_total) {
         return;
     }
 
-    const int64_t i03 = i / (ne00 * ne01 * ne02);
-    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
-    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
-    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
+    uint32_t tmp = (uint32_t) i;
+    uint2    div_mod;
+
+    div_mod           = fast_div_modulo(tmp, ne00);
+    const int64_t i00 = div_mod.y;
+    tmp               = div_mod.x;
 
-    const int64_t i12 = i03 % ne12;
-    const int64_t i11 = i02 % ne11;
+    div_mod           = fast_div_modulo(tmp, ne01);
+    const int64_t i01 = div_mod.y;
+    tmp               = div_mod.x;
+
+    div_mod           = fast_div_modulo(tmp, ne02);
+    const int64_t i02 = div_mod.y;
+    const int64_t i03 = div_mod.x;
+
+    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
     const int64_t i10 = i01;
 
     const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
     dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 
     GGML_UNUSED(ne10);
+    GGML_UNUSED(ne11);
+    GGML_UNUSED(ne12);
     GGML_UNUSED(ne13);
 }
 
@@ -144,14 +196,16 @@ static void set_rows_cuda(
     const int64_t s2  = nb2/sizeof(dst_t);
     const int64_t s3  = nb3/sizeof(dst_t);
 
-    if (ne_total > 0) {
-        k_set_rows<<<grid_size, block_size, 0, stream>>>(
-            src0_d, src1_d, dst_d,
-            ne00, ne01, ne02, ne03,
-            ne10, ne11, ne12, ne13,
-            s01, s02, s03,
-            s10, s11, s12,
-            s1, s2, s3);
+    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
+        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
+                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
+                                                         ne11_fd, ne12_fd);
     }
 }