]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda: cap grid.y at 65535 in non-contiguous dequantize/convert kernels (llama/19999)
authoroobabooga <redacted>
Sun, 1 Mar 2026 05:40:22 +0000 (02:40 -0300)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
ggml/src/ggml-cuda/convert.cu

index 09b6d5db6a023e167768362b90694f857dc0564b..b70492c7d6cf7eb7c24a8b351600da2e1fafb3c7 100644 (file)
@@ -16,27 +16,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
         return;
     }
 
-    const int64_t i01 = blockIdx.y;
-
-    for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
-        const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
-        const int64_t i02 = dm.y;
-        const int64_t i03 = dm.x;
-
-        const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
-
-        const int64_t ib = ibx0 + i00/qk; // block index
-        const int64_t iqs = (i00%qk)/qr; // quant index
-        const int64_t iybs = i00 - i00%qk; // y block start index
-        const int64_t y_offset = qr == 1 ? 1 : qk/2;
-
-        // dequantize
-        float2 v;
-        dequantize_kernel(vx, ib, iqs, v);
-
-        const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
-        y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
-        y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
+        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+            const int64_t i02 = dm.y;
+            const int64_t i03 = dm.x;
+
+            const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+            const int64_t ib = ibx0 + i00/qk; // block index
+            const int64_t iqs = (i00%qk)/qr; // quant index
+            const int64_t iybs = i00 - i00%qk; // y block start index
+            const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+            // dequantize
+            float2 v;
+            dequantize_kernel(vx, ib, iqs, v);
+
+            const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
+            y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
+            y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+        }
     }
 }
 
@@ -492,7 +492,7 @@ static void dequantize_block_cuda(const void * vx, dst_t * y,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
     const int64_t ne0203 = ne02*ne03;
     const uint3 ne02_fdv = init_fastdiv_values(ne02);
-    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
+    const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
     dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
         (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }
@@ -628,18 +628,18 @@ static __global__ void convert_unary(
         return;
     }
 
-    const int64_t i01 = blockIdx.y;
-
     const src_t * x = (const src_t *) vx;
 
-    for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
-        const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
-        const int64_t i02 = dm.y;
-        const int64_t i03 = dm.x;
+    for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) {
+        for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
+            const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
+            const int64_t i02 = dm.y;
+            const int64_t i03 = dm.x;
 
-        const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
-        const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
-        y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
+            const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+            const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
+            y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
+        }
     }
 }
 
@@ -649,7 +649,7 @@ static void convert_unary_cuda(const void * vx, dst_t * y,
         const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
     const int64_t ne0203 = ne02*ne03;
     const uint3 ne02_fdv = init_fastdiv_values(ne02);
-    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
+    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535));
     convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
         (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
 }