]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: loop over ne2*ne3 in case it overflows (llama/19538)
authorAman Gupta <redacted>
Fri, 13 Feb 2026 11:31:40 +0000 (17:01 +0530)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
* CUDA: loop over ne2*ne3 in case it overflows

* use fastdiv

src/ggml-cuda/convert.cu

index ba3d4eeb88085492eae763e965d624d82e246ed9..09b6d5db6a023e167768362b90694f857dc0564b 100644 (file)
@@ -7,7 +7,8 @@
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
-        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t ne00, const int64_t ne01,
+        const int64_t ne0203, const uint3 ne02,
         const int64_t s01, const int64_t s02, const int64_t s03) {
     const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
 
@@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     }
 
     const int64_t i01 = blockIdx.y;
-    const int64_t i02 = blockIdx.z % ne02;
-    const int64_t i03 = blockIdx.z / ne02;
 
-    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+    for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+        const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+        const int64_t i02 = dm.y;
+        const int64_t i03 = dm.x;
 
-    const int64_t ib = ibx0 + i00/qk; // block index
-    const int64_t iqs = (i00%qk)/qr; // quant index
-    const int64_t iybs = i00 - i00%qk; // y block start index
-    const int64_t y_offset = qr == 1 ? 1 : qk/2;
+        const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
 
-    // dequantize
-    float2 v;
-    dequantize_kernel(vx, ib, iqs, v);
+        const int64_t ib = ibx0 + i00/qk; // block index
+        const int64_t iqs = (i00%qk)/qr; // quant index
+        const int64_t iybs = i00 - i00%qk; // y block start index
+        const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
-    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
-    y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
-    y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+        // dequantize
+        float2 v;
+        dequantize_kernel(vx, ib, iqs, v);
+
+        const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
+        y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
+        y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+    }
 }
 
 template <bool need_check>
@@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
-    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+    const int64_t ne0203 = ne02*ne03;
+    const uint3 ne02_fdv = init_fastdiv_values(ne02);
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
-        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
 
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
-        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
+        const int64_t ne0203, const uint3 ne02,
         const int64_t s01, const int64_t s02, const int64_t s03) {
     const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -621,23 +629,29 @@ static __global__ void convert_unary(
     }
 
     const int64_t i01 = blockIdx.y;
-    const int64_t i02 = blockIdx.z % ne02;
-    const int64_t i03 = blockIdx.z / ne02;
 
     const src_t * x = (const src_t *) vx;
 
-    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
-    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
+    for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+        const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+        const int64_t i02 = dm.y;
+        const int64_t i03 = dm.x;
+
+        const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+        const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
+        y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
+    }
 }
 
 template <typename src_t, typename dst_t>
 static void convert_unary_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
-    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
+    const int64_t ne0203 = ne02*ne03;
+    const uint3 ne02_fdv = init_fastdiv_values(ne02);
+    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
-        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+        (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }
 
 template <typename src_t, typename dst_t>